# Classifier.py

from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
    svm_save_model, svm_load_model, RBF
  3. class Classifier:
  4. def __init__(self, c=None, gamma=None, filename=None, neighbours=3, \
  5. verbose=0):
  6. self.neighbours = neighbours
  7. self.verbose = verbose
  8. if filename:
  9. # If a filename is given, load a model from the given filename
  10. if verbose:
  11. print 'Loading classifier from "%s"...' % filename
  12. self.model = svm_load_model(filename)
  13. elif c == None or gamma == None:
  14. raise Exception('Please specify both C and gamma.')
  15. else:
  16. self.param = svm_parameter()
  17. self.param.C = c # Soft margin
  18. self.param.kernel_type = RBF # Radial kernel type
  19. self.param.gamma = gamma # Parameter for radial kernel
  20. self.model = None
  21. def save(self, filename):
  22. """Save the SVM model in the given filename."""
  23. if self.verbose:
  24. print 'Saving classifier in "%s"...' % filename
  25. svm_save_model(filename, self.model)
  26. def train(self, learning_set):
  27. """Train the classifier with a list of character objects that have
  28. known values."""
  29. classes = []
  30. features = []
  31. l = len(learning_set)
  32. for i, char in enumerate(learning_set):
  33. if self.verbose:
  34. print 'Found "%s" -- %d of %d (%d%% done)' \
  35. % (char.value, i + 1, l, round(100 * (i + 1) / l))
  36. classes.append(float(ord(char.value)))
  37. #features.append(char.get_feature_vector())
  38. char.get_single_cell_feature_vector(self.neighbours)
  39. features.append(char.feature)
  40. problem = svm_problem(classes, features)
  41. self.model = svm_train(problem, self.param)
  42. def test(self, test_set):
  43. """Test the classifier with the given test set and return the score."""
  44. matches = 0
  45. for char in test_set:
  46. prediction = self.classify(char)
  47. if char.value == prediction:
  48. matches += 1
  49. return float(matches) / len(test_set)
  50. def classify(self, character, true_value=None):
  51. """Classify a character object, return its value."""
  52. true_value = 0 if true_value == None else ord(true_value)
  53. #x = character.get_feature_vector(self.cell_size)
  54. character.get_single_cell_feature_vector(self.neighbours)
  55. #p = svm_predict([true_value], [character.feature], self.model, '-b 1')
  56. p = svm_predict([true_value], [character.feature], self.model)
  57. prediction_class = int(p[0][0])
  58. return chr(prediction_class)