Kaynağa Gözat

Merge branch 'master' of https://github.com/taddeus/licenseplates

unknown 14 yıl önce
ebeveyn
işleme
6103bbdd53
2 değiştirilmiş dosya ile 56 ekleme ve 2 silme
  1. 5 2
      src/Character.py
  2. 51 0
      src/Classifier.py

+ 5 - 2
src/Character.py

@@ -11,7 +11,7 @@ class Character:
 
 
     def set_corners(self):
     def set_corners(self):
         corners = self.get_children("quadrangle")
         corners = self.get_children("quadrangle")
-  
+
         self.corners = []
         self.corners = []
 
 
         for corner in corners:
         for corner in corners:
@@ -25,4 +25,7 @@ class Character:
         return dom.getElementsByTagName(node)[0]
         return dom.getElementsByTagName(node)[0]
 
 
     def get_children(self, node, dom=None):
     def get_children(self, node, dom=None):
-        return self.get_node(node, dom).childNodes
+        return self.get_node(node, dom).childNodes
+
+    def get_feature_vector(self):
+        pass

+ 51 - 0
src/Classifier.py

@@ -0,0 +1,51 @@
+from svmutil import svm_model, svm_problem, svm_parameter, svm_predict, LINEAR
+from cPicle import dump, load
+
+
+class Classifier:
+    def __init__(self, c=None, filename=None):
+        if filename:
+            # If a filename is given, load a modl from the fiven filename
+            f = file(filename, 'r')
+            self.model, self.param, self.character_map = load(f)
+            f.close()
+        else:
+            self.param = svm_parameter()
+            self.param.kernel_type = LINEAR
+            self.param.C = c
+            self.character_map = {}
+            self.model = None
+
+    def save(self, filename):
+        """Save the SVM model in the given filename."""
+        f = file(filename, 'w+')
+        dump((self.model, self.param, self.character_map), f)
+        f.close()
+
+    def train(self, learning_set):
+        """Train the classifier with a list of character objects that have
+        known values."""
+        classes = []
+        features = []
+
+        for char in learning_set:
+            # Map the character to an integer for use in the SVM model
+            if char.value not in self.character_map:
+                self.character_map[char.value] = len(self.character_map)
+
+            classes.append(self.character_map[char.value])
+            features.append(char.get_feature_vector())
+
+        problem = svm_problem(self.c, features)
+        self.model = svm_model(problem, self.param)
+
+        # Add prediction function that returns a numeric class prediction
+        self.model.predict = lambda self, x: svm_predict([0], [x], self)[0][0]
+
+    def classify(self, character):
+        """Classify a character object and assign its value."""
+        prediction = self.model.predict(character.get_feature_vector())
+
+        for value, svm_class in self.character_map.iteritems():
+            if svm_class == prediction:
+                return value