Classifier.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
  2. svm_save_model, svm_load_model, RBF
  3. class Classifier:
  4. def __init__(self, c=None, gamma=None, filename=None, cell_size=12):
  5. self.cell_size = cell_size
  6. if filename:
  7. # If a filename is given, load a model from the given filename
  8. self.model = svm_load_model(filename)
  9. elif c == None or gamma == None:
  10. raise Exception('Please specify both C and gamma.')
  11. else:
  12. self.param = svm_parameter()
  13. self.param.C = c # Soft margin
  14. self.param.kernel_type = RBF # Radial kernel type
  15. self.param.gamma = gamma # Parameter for radial kernel
  16. self.model = None
  17. def save(self, filename):
  18. """Save the SVM model in the given filename."""
  19. svm_save_model(filename, self.model)
  20. def train(self, learning_set):
  21. """Train the classifier with a list of character objects that have
  22. known values."""
  23. classes = []
  24. features = []
  25. l = len(learning_set)
  26. for i, char in enumerate(learning_set):
  27. print 'Found "%s" -- %d of %d (%d%% done)' \
  28. % (char.value, i + 1, l, int(100 * (i + 1) / l))
  29. classes.append(float(ord(char.value)))
  30. #features.append(char.get_feature_vector())
  31. char.get_single_cell_feature_vector()
  32. features.append(char.feature)
  33. problem = svm_problem(classes, features)
  34. self.model = svm_train(problem, self.param)
  35. def test(self, test_set):
  36. """Test the classifier with the given test set and return the score."""
  37. matches = 0
  38. for char in test_set:
  39. prediction = self.classify(char)
  40. if char.value == prediction:
  41. matches += 1
  42. return float(matches) / len(test_set)
  43. def classify(self, character, true_value=None):
  44. """Classify a character object, return its value."""
  45. true_value = 0 if true_value == None else ord(true_value)
  46. #x = character.get_feature_vector(self.cell_size)
  47. character.get_single_cell_feature_vector()
  48. p = svm_predict([true_value], [character.feature], self.model)
  49. prediction_class = int(p[0][0])
  50. return chr(prediction_class)