Commit b9bc273e authored by Taddeus Kroes's avatar Taddeus Kroes

Replaced character map with char-to-int cast in SVM trainer and shrunk learning- and testsets.

parent a68fe0cd
...@@ -7,23 +7,16 @@ class Classifier: ...@@ -7,23 +7,16 @@ class Classifier:
def __init__(self, c=None, filename=None): def __init__(self, c=None, filename=None):
if filename: if filename:
# If a filename is given, load a modl from the fiven filename # If a filename is given, load a modl from the fiven filename
self.model = svm_load_model(filename + '-model') self.model = svm_load_model(filename)
f = file(filename + '-characters', 'r')
self.character_map = load(f)
f.close()
else: else:
self.param = svm_parameter() self.param = svm_parameter()
self.param.kernel_type = 2 self.param.kernel_type = 2
self.param.C = c self.param.C = c
self.character_map = {}
self.model = None self.model = None
def save(self, filename): def save(self, filename):
"""Save the SVM model in the given filename.""" """Save the SVM model in the given filename."""
svm_save_model(filename + '-model', self.model) svm_save_model(filename, self.model)
f = file(filename + '-characters', 'w+')
dump(self.character_map, f)
f.close()
def train(self, learning_set): def train(self, learning_set):
"""Train the classifier with a list of character objects that have """Train the classifier with a list of character objects that have
...@@ -35,11 +28,7 @@ class Classifier: ...@@ -35,11 +28,7 @@ class Classifier:
for i, char in enumerate(learning_set): for i, char in enumerate(learning_set):
print 'Training "%s" -- %d of %d (%d%% done)' \ print 'Training "%s" -- %d of %d (%d%% done)' \
% (char.value, i + 1, l, int(100 * (i + 1) / l)) % (char.value, i + 1, l, int(100 * (i + 1) / l))
# Map the character to an integer for use in the SVM model classes.append(float(ord(char.value)))
if char.value not in self.character_map:
self.character_map[char.value] = len(self.character_map)
classes.append(self.character_map[char.value])
features.append(char.get_feature_vector()) features.append(char.get_feature_vector())
problem = svm_problem(classes, features) problem = svm_problem(classes, features)
...@@ -48,8 +37,6 @@ class Classifier: ...@@ -48,8 +37,6 @@ class Classifier:
def classify(self, character): def classify(self, character):
"""Classify a character object and assign its value.""" """Classify a character object and assign its value."""
predict = lambda x: svm_predict([0], [x], self.model)[0][0] predict = lambda x: svm_predict([0], [x], self.model)[0][0]
prediction = predict(character.get_feature_vector()) prediction_class = predict(character.get_feature_vector())
for value, svm_class in self.character_map.iteritems(): return chr(int(prediction_class))
if svm_class == prediction:
return value
...@@ -3,36 +3,38 @@ from LicensePlate import LicensePlate ...@@ -3,36 +3,38 @@ from LicensePlate import LicensePlate
from Classifier import Classifier from Classifier import Classifier
from cPickle import dump, load from cPickle import dump, load
chars = [] #chars = []
#
for i in range(9): #for i in range(9):
for j in range(100): # for j in range(100):
try: # try:
filename = '%04d/00991_%04d%02d.info' % (i, i, j) # filename = '%04d/00991_%04d%02d.info' % (i, i, j)
print 'loading file "%s"' % filename # print 'loading file "%s"' % filename
plate = LicensePlate(i, j) # plate = LicensePlate(i, j)
#
if hasattr(plate, 'characters'): # if hasattr(plate, 'characters'):
chars.extend(plate.characters) # chars.extend(plate.characters)
except: # except:
print 'epic fail' # print 'epic fail'
#
print 'loaded %d chars' % len(chars) #print 'loaded %d chars' % len(chars)
#
dump(chars, file('chars', 'w+')) #dump(chars, file('chars', 'w+'))
#---------------------------------------------------------------- #----------------------------------------------------------------
chars = load(file('chars', 'r')) chars = load(file('chars', 'r'))[:500]
learned = [] learned = []
learning_set = [] learning_set = []
test_set = [] test_set = []
for char in chars: for char in chars:
if learned.count(char.value) > 80: if learned.count(char.value) > 12:
test_set.append(char) test_set.append(char)
else: else:
learning_set.append(char) learning_set.append(char)
learned.append(char.value) learned.append(char.value)
#print 'Learning set:', [c.value for c in learning_set]
#print 'Test set:', [c.value for c in test_set]
dump(learning_set, file('learning_set', 'w+')) dump(learning_set, file('learning_set', 'w+'))
dump(test_set, file('test_set', 'w+')) dump(test_set, file('test_set', 'w+'))
#---------------------------------------------------------------- #----------------------------------------------------------------
...@@ -52,7 +54,7 @@ for i, char in enumerate(test_set): ...@@ -52,7 +54,7 @@ for i, char in enumerate(test_set):
prediction = classifier.classify(char) prediction = classifier.classify(char)
if char.value == prediction: if char.value == prediction:
print ':) ------> Successfully recognized "%s"' % char.value, print ':-----> Successfully recognized "%s"' % char.value,
matches += 1 matches += 1
else: else:
print ':( Expected character "%s", got "%s"' \ print ':( Expected character "%s", got "%s"' \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment