# Classifier.py
try:
    from cPickle import dump, load  # Python 2: C-accelerated pickle
except ImportError:
    from pickle import dump, load  # Python 3: cPickle was merged into pickle

from svmutil import (svm_model, svm_problem, svm_parameter, svm_predict,
                     svm_train, LINEAR)
  3. class Classifier:
  4. def __init__(self, c=None, filename=None):
  5. if filename:
  6. # If a filename is given, load a modl from the fiven filename
  7. f = file(filename, 'r')
  8. self.model, self.param, self.character_map = load(f)
  9. f.close()
  10. else:
  11. self.param = svm_parameter()
  12. self.param.kernel_type = LINEAR
  13. self.param.C = c
  14. self.character_map = {}
  15. self.model = None
  16. def save(self, filename):
  17. """Save the SVM model in the given filename."""
  18. f = file(filename, 'w+')
  19. dump((self.model, self.param, self.character_map), f)
  20. f.close()
  21. def train(self, learning_set):
  22. """Train the classifier with a list of character objects that have
  23. known values."""
  24. classes = []
  25. features = []
  26. for char in learning_set:
  27. # Map the character to an integer for use in the SVM model
  28. if char.value not in self.character_map:
  29. self.character_map[char.value] = len(self.character_map)
  30. classes.append(self.character_map[char.value])
  31. features.append(char.get_feature_vector())
  32. problem = svm_problem(self.c, features)
  33. self.model = svm_model(problem, self.param)
  34. # Add prediction function that returns a numeric class prediction
  35. self.model.predict = lambda self, x: svm_predict([0], [x], self)[0][0]
  36. def classify(self, character):
  37. """Classify a character object and assign its value."""
  38. prediction = self.model.predict(character.get_feature_vector())
  39. for value, svm_class in self.character_map.iteritems():
  40. if svm_class == prediction:
  41. return value