Преглед изворни кода

Replaced character map with char-to-int cast in SVM trainer and shrunk learning- and testsets.

Taddeus Kroes пре 14 година
родитељ
комит
b9bc273e42
2 измењених фајлова са 27 додато и 38 уклоњено
  1. 5 18
      src/Classifier.py
  2. 22 20
      src/ClassifierTest.py

+ 5 - 18
src/Classifier.py

@@ -7,23 +7,16 @@ class Classifier:
     def __init__(self, c=None, filename=None):
         if filename:
             # If a filename is given, load a modl from the fiven filename
-            self.model = svm_load_model(filename + '-model')
-            f = file(filename + '-characters', 'r')
-            self.character_map = load(f)
-            f.close()
+            self.model = svm_load_model(filename)
         else:
             self.param = svm_parameter()
             self.param.kernel_type = 2
             self.param.C = c
-            self.character_map = {}
             self.model = None
 
     def save(self, filename):
         """Save the SVM model in the given filename."""
-        svm_save_model(filename + '-model', self.model)
-        f = file(filename + '-characters', 'w+')
-        dump(self.character_map, f)
-        f.close()
+        svm_save_model(filename, self.model)
 
     def train(self, learning_set):
         """Train the classifier with a list of character objects that have
@@ -35,11 +28,7 @@ class Classifier:
         for i, char in enumerate(learning_set):
             print 'Training "%s"  --  %d of %d (%d%% done)' \
                     % (char.value, i + 1, l, int(100 * (i + 1) / l))
-            # Map the character to an integer for use in the SVM model
-            if char.value not in self.character_map:
-                self.character_map[char.value] = len(self.character_map)
-
-            classes.append(self.character_map[char.value])
+            classes.append(float(ord(char.value)))
             features.append(char.get_feature_vector())
 
         problem = svm_problem(classes, features)
@@ -48,8 +37,6 @@ class Classifier:
     def classify(self, character):
         """Classify a character object and assign its value."""
         predict = lambda x: svm_predict([0], [x], self.model)[0][0]
-        prediction = predict(character.get_feature_vector())
+        prediction_class = predict(character.get_feature_vector())
 
-        for value, svm_class in self.character_map.iteritems():
-            if svm_class == prediction:
-                return value
+        return chr(int(prediction_class))

+ 22 - 20
src/ClassifierTest.py

@@ -3,36 +3,38 @@ from LicensePlate import LicensePlate
 from Classifier import Classifier
 from cPickle import dump, load
 
-chars = []
-
-for i in range(9):
-    for j in range(100):
-        try:
-            filename = '%04d/00991_%04d%02d.info' % (i, i, j)
-            print 'loading file "%s"' % filename
-            plate = LicensePlate(i, j)
-
-            if hasattr(plate, 'characters'):
-                chars.extend(plate.characters)
-        except:
-            print 'epic fail'
-
-print 'loaded %d chars' % len(chars)
-
-dump(chars, file('chars', 'w+'))
+#chars = []
+#
+#for i in range(9):
+#    for j in range(100):
+#        try:
+#            filename = '%04d/00991_%04d%02d.info' % (i, i, j)
+#            print 'loading file "%s"' % filename
+#            plate = LicensePlate(i, j)
+#
+#            if hasattr(plate, 'characters'):
+#                chars.extend(plate.characters)
+#        except:
+#            print 'epic fail'
+#
+#print 'loaded %d chars' % len(chars)
+#
+#dump(chars, file('chars', 'w+'))
 #----------------------------------------------------------------
-chars = load(file('chars', 'r'))
+chars = load(file('chars', 'r'))[:500]
 learned = []
 learning_set = []
 test_set = []
 
 for char in chars:
-    if learned.count(char.value) > 80:
+    if learned.count(char.value) > 12:
         test_set.append(char)
     else:
         learning_set.append(char)
         learned.append(char.value)
 
+#print 'Learning set:', [c.value for c in learning_set]
+#print 'Test set:', [c.value for c in test_set]
 dump(learning_set, file('learning_set', 'w+'))
 dump(test_set, file('test_set', 'w+'))
 #----------------------------------------------------------------
@@ -52,7 +54,7 @@ for i, char in enumerate(test_set):
     prediction = classifier.classify(char)
 
     if char.value == prediction:
-        print ':) ------> Successfully recognized "%s"' % char.value,
+        print ':-----> Successfully recognized "%s"' % char.value,
         matches += 1
     else:
         print ':( Expected character "%s", got "%s"' \