Classifier.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
  2. svm_save_model, svm_load_model
  class Classifier:
      """Character classifier backed by a libsvm model (via ``svmutil``).

      Class labels are the character code points (``ord(char.value)``) stored
      as floats; ``classify`` converts a predicted label back with ``chr``.
      Training items and classified items must expose ``get_feature_vector()``;
      training items must additionally expose a one-character ``value``.
      """

      def __init__(self, c=None, filename=None):
          """Create a classifier, optionally loading a saved model.

          Args:
              c: Cost parameter C for the SVM. When ``None``, the value that
                  ``svm_parameter()`` starts with is left untouched.
              filename: If truthy, path of a model file to load with
                  ``svm_load_model``. Otherwise the classifier starts without
                  a model and must be trained before use.
          """
          # The parameters are created unconditionally so that a classifier
          # restored from a file can still be (re)trained; previously
          # ``self.param`` did not exist in that case and train() failed with
          # an AttributeError.
          self.param = svm_parameter()
          self.param.kernel_type = 2  # Radial kernel type
          if c is not None:
              # Only override C when a value was actually supplied; assigning
              # None would not give libsvm a usable cost value.
              self.param.C = c

          # If a filename is given, load a model from the given filename.
          self.model = svm_load_model(filename) if filename else None

      def save(self, filename):
          """Save the SVM model in the given filename.

          Raises:
              RuntimeError: If no model has been trained or loaded yet.
          """
          if self.model is None:
              raise RuntimeError('Cannot save: no model has been trained or '
                                 'loaded.')
          svm_save_model(filename, self.model)

      def train(self, learning_set):
          """Train the classifier with character objects that have known values.

          Progress is printed to standard output, one line per character.

          Args:
              learning_set: Sized sequence of character objects, each with a
                  ``value`` attribute and a ``get_feature_vector()`` method.

          Raises:
              ValueError: If ``learning_set`` is empty.
          """
          total = len(learning_set)
          if not total:
              raise ValueError('Cannot train on an empty learning set.')

          classes = []
          features = []
          for index, char in enumerate(learning_set):
              done = index + 1
              # Single-argument print() behaves the same on Python 2 and 3;
              # floor division keeps the percentage an integer on both.
              print('Training "%s" -- %d of %d (%d%% done)'
                    % (char.value, done, total, 100 * done // total))
              classes.append(float(ord(char.value)))
              features.append(char.get_feature_vector())

          problem = svm_problem(classes, features)
          self.model = svm_train(problem, self.param)

      def classify(self, character):
          """Classify a character object, return its value.

          Args:
              character: Object exposing ``get_feature_vector()``.

          Returns:
              The one-character string whose code point is the predicted label.

          Raises:
              RuntimeError: If no model has been trained or loaded yet.
          """
          if self.model is None:
              raise RuntimeError('Cannot classify: no model has been trained '
                                 'or loaded.')
          # The true label is unknown, so a dummy label of 0 is passed; the
          # first element of svm_predict's result is indexed for the single
          # predicted label.
          predicted_labels = svm_predict(
              [0], [character.get_feature_vector()], self.model)[0]
          return chr(int(predicted_labels[0]))