ソースを参照

Merge branch 'master' of github.com:taddeus/licenseplates

Jayke Meijer 14 年 前
コミット
24b7b1d7d2
2 ファイル変更14 行追加26 行削除
  1. 9 23
      src/Classifier.py
  2. 5 3
      src/ClassifierTest.py

+ 9 - 23
src/Classifier.py

@@ -1,29 +1,21 @@
 from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
         svm_save_model, svm_load_model
-from cPickle import dump, load
 
 
 class Classifier:
     def __init__(self, c=None, filename=None):
         if filename:
-            # If a filename is given, load a modl from the fiven filename
-            self.model = svm_load_model(filename + '-model')
-            f = file(filename + '-characters', 'r')
-            self.character_map = load(f)
-            f.close()
+            # If a filename is given, load a model from the given filename
+            self.model = svm_load_model(filename)
         else:
             self.param = svm_parameter()
-            self.param.kernel_type = 2
+            self.param.kernel_type = 2  # Radial kernel type
             self.param.C = c
-            self.character_map = {}
             self.model = None
 
     def save(self, filename):
         """Save the SVM model in the given filename."""
-        svm_save_model(filename + '-model', self.model)
-        f = file(filename + '-characters', 'w+')
-        dump(self.character_map, f)
-        f.close()
+        svm_save_model(filename, self.model)
 
     def train(self, learning_set):
         """Train the classifier with a list of character objects that have
@@ -34,22 +26,16 @@ class Classifier:
 
         for i, char in enumerate(learning_set):
             print 'Training "%s"  --  %d of %d (%d%% done)' \
-                    % (char.value, i + 1, l, int(100 * (i + 1) / l))
-            # Map the character to an integer for use in the SVM model
-            if char.value not in self.character_map:
-                self.character_map[char.value] = len(self.character_map)
-
-            classes.append(self.character_map[char.value])
+                  % (char.value, i + 1, l, int(100 * (i + 1) / l))
+            classes.append(float(ord(char.value)))
             features.append(char.get_feature_vector())
 
         problem = svm_problem(classes, features)
         self.model = svm_train(problem, self.param)
 
     def classify(self, character):
-        """Classify a character object and assign its value."""
+        """Classify a character object, return its value."""
         predict = lambda x: svm_predict([0], [x], self.model)[0][0]
-        prediction = predict(character.get_feature_vector())
+        prediction_class = predict(character.get_feature_vector())
 
-        for value, svm_class in self.character_map.iteritems():
-            if svm_class == prediction:
-                return value
+        return chr(int(prediction_class))

+ 5 - 3
src/ClassifierTest.py

@@ -21,18 +21,20 @@ print 'loaded %d chars' % len(chars)
 
 dump(chars, file('chars', 'w+'))
 #----------------------------------------------------------------
-chars = load(file('chars', 'r'))
+chars = load(file('chars', 'r'))[:500]
 learned = []
 learning_set = []
 test_set = []
 
 for char in chars:
-    if learned.count(char.value) > 80:
+    if learned.count(char.value) > 12:
         test_set.append(char)
     else:
         learning_set.append(char)
         learned.append(char.value)
 
+#print 'Learning set:', [c.value for c in learning_set]
+#print 'Test set:', [c.value for c in test_set]
 dump(learning_set, file('learning_set', 'w+'))
 dump(test_set, file('test_set', 'w+'))
 #----------------------------------------------------------------
@@ -52,7 +54,7 @@ for i, char in enumerate(test_set):
     prediction = classifier.classify(char)
 
     if char.value == prediction:
-        print ':) ------> Successfully recognized "%s"' % char.value,
+        print ':-----> Successfully recognized "%s"' % char.value,
         matches += 1
     else:
         print ':( Expected character "%s", got "%s"' \