Commit 24b7b1d7 authored by Jayke Meijer's avatar Jayke Meijer

Merge branch 'master' of github.com:taddeus/licenseplates

parents ccfd5c3b 9749ba2d
from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \ from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
svm_save_model, svm_load_model svm_save_model, svm_load_model
from cPickle import dump, load
class Classifier: class Classifier:
def __init__(self, c=None, filename=None): def __init__(self, c=None, filename=None):
if filename: if filename:
# If a filename is given, load a modl from the fiven filename # If a filename is given, load a model from the given filename
self.model = svm_load_model(filename + '-model') self.model = svm_load_model(filename)
f = file(filename + '-characters', 'r')
self.character_map = load(f)
f.close()
else: else:
self.param = svm_parameter() self.param = svm_parameter()
self.param.kernel_type = 2 self.param.kernel_type = 2 # Radial kernel type
self.param.C = c self.param.C = c
self.character_map = {}
self.model = None self.model = None
def save(self, filename): def save(self, filename):
"""Save the SVM model in the given filename.""" """Save the SVM model in the given filename."""
svm_save_model(filename + '-model', self.model) svm_save_model(filename, self.model)
f = file(filename + '-characters', 'w+')
dump(self.character_map, f)
f.close()
def train(self, learning_set): def train(self, learning_set):
"""Train the classifier with a list of character objects that have """Train the classifier with a list of character objects that have
...@@ -34,22 +26,16 @@ class Classifier: ...@@ -34,22 +26,16 @@ class Classifier:
for i, char in enumerate(learning_set): for i, char in enumerate(learning_set):
print 'Training "%s" -- %d of %d (%d%% done)' \ print 'Training "%s" -- %d of %d (%d%% done)' \
% (char.value, i + 1, l, int(100 * (i + 1) / l)) % (char.value, i + 1, l, int(100 * (i + 1) / l))
# Map the character to an integer for use in the SVM model classes.append(float(ord(char.value)))
if char.value not in self.character_map:
self.character_map[char.value] = len(self.character_map)
classes.append(self.character_map[char.value])
features.append(char.get_feature_vector()) features.append(char.get_feature_vector())
problem = svm_problem(classes, features) problem = svm_problem(classes, features)
self.model = svm_train(problem, self.param) self.model = svm_train(problem, self.param)
def classify(self, character): def classify(self, character):
"""Classify a character object and assign its value.""" """Classify a character object, return its value."""
predict = lambda x: svm_predict([0], [x], self.model)[0][0] predict = lambda x: svm_predict([0], [x], self.model)[0][0]
prediction = predict(character.get_feature_vector()) prediction_class = predict(character.get_feature_vector())
for value, svm_class in self.character_map.iteritems(): return chr(int(prediction_class))
if svm_class == prediction:
return value
...@@ -21,18 +21,20 @@ print 'loaded %d chars' % len(chars) ...@@ -21,18 +21,20 @@ print 'loaded %d chars' % len(chars)
dump(chars, file('chars', 'w+')) dump(chars, file('chars', 'w+'))
#---------------------------------------------------------------- #----------------------------------------------------------------
chars = load(file('chars', 'r')) chars = load(file('chars', 'r'))[:500]
learned = [] learned = []
learning_set = [] learning_set = []
test_set = [] test_set = []
for char in chars: for char in chars:
if learned.count(char.value) > 80: if learned.count(char.value) > 12:
test_set.append(char) test_set.append(char)
else: else:
learning_set.append(char) learning_set.append(char)
learned.append(char.value) learned.append(char.value)
#print 'Learning set:', [c.value for c in learning_set]
#print 'Test set:', [c.value for c in test_set]
dump(learning_set, file('learning_set', 'w+')) dump(learning_set, file('learning_set', 'w+'))
dump(test_set, file('test_set', 'w+')) dump(test_set, file('test_set', 'w+'))
#---------------------------------------------------------------- #----------------------------------------------------------------
...@@ -52,7 +54,7 @@ for i, char in enumerate(test_set): ...@@ -52,7 +54,7 @@ for i, char in enumerate(test_set):
prediction = classifier.classify(char) prediction = classifier.classify(char)
if char.value == prediction: if char.value == prediction:
print ':) ------> Successfully recognized "%s"' % char.value, print ':-----> Successfully recognized "%s"' % char.value,
matches += 1 matches += 1
else: else:
print ':( Expected character "%s", got "%s"' \ print ':( Expected character "%s", got "%s"' \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment