Classifier.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
  2. svm_save_model, svm_load_model, RBF
  3. class Classifier:
  4. def __init__(self, c=None, gamma=None, filename=None, neighbours=3, \
  5. verbose=0):
  6. if filename:
  7. # If a filename is given, load a model from the given filename
  8. self.model = svm_load_model(filename)
  9. elif c == None or gamma == None:
  10. raise Exception('Please specify both C and gamma.')
  11. else:
  12. self.param = svm_parameter()
  13. self.param.C = c # Soft margin
  14. self.param.kernel_type = RBF # Radial kernel type
  15. self.param.gamma = gamma # Parameter for radial kernel
  16. self.model = None
  17. self.neighbours = neighbours
  18. self.verbose = verbose
  19. def save(self, filename):
  20. """Save the SVM model in the given filename."""
  21. svm_save_model(filename, self.model)
  22. def train(self, learning_set):
  23. """Train the classifier with a list of character objects that have
  24. known values."""
  25. classes = []
  26. features = []
  27. l = len(learning_set)
  28. for i, char in enumerate(learning_set):
  29. if self.verbose:
  30. print 'Found "%s" -- %d of %d (%d%% done)' \
  31. % (char.value, i + 1, l, round(100 * (i + 1) / l))
  32. classes.append(float(ord(char.value)))
  33. #features.append(char.get_feature_vector())
  34. char.get_single_cell_feature_vector(self.neighbours)
  35. features.append(char.feature)
  36. problem = svm_problem(classes, features)
  37. self.model = svm_train(problem, self.param)
  38. def test(self, test_set):
  39. """Test the classifier with the given test set and return the score."""
  40. matches = 0
  41. for char in test_set:
  42. prediction = self.classify(char)
  43. if char.value == prediction:
  44. matches += 1
  45. return float(matches) / len(test_set)
  46. def classify(self, character, true_value=None):
  47. """Classify a character object, return its value."""
  48. true_value = 0 if true_value == None else ord(true_value)
  49. #x = character.get_feature_vector(self.cell_size)
  50. character.get_single_cell_feature_vector(self.neighbours)
  51. #p = svm_predict([true_value], [character.feature], self.model, '-b 1')
  52. p = svm_predict([true_value], [character.feature], self.model)
  53. prediction_class = int(p[0][0])
  54. return chr(prediction_class)