Commit b91057fc authored by Taddeus Kroes's avatar Taddeus Kroes

Merged classifier test scripts.

parent 45911a87
......@@ -3,7 +3,8 @@ from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
class Classifier:
def __init__(self, c=None, gamma=None, filename=None, neighbours=3):
def __init__(self, c=None, gamma=None, filename=None, neighbours=3, \
verbose=0):
self.neighbours = neighbours
if filename:
......@@ -18,6 +19,8 @@ class Classifier:
self.param.gamma = gamma # Parameter for radial kernel
self.model = None
self.verbose = verbose
def save(self, filename):
"""Save the SVM model in the given filename."""
svm_save_model(filename, self.model)
......@@ -30,8 +33,9 @@ class Classifier:
l = len(learning_set)
for i, char in enumerate(learning_set):
if self.verbose:
print 'Found "%s" -- %d of %d (%d%% done)' \
% (char.value, i + 1, l, int(100 * (i + 1) / l))
% (char.value, i + 1, l, round(100 * (i + 1) / l))
classes.append(float(ord(char.value)))
#features.append(char.get_feature_vector())
char.get_single_cell_feature_vector(self.neighbours)
......
......@@ -86,7 +86,7 @@ i = 0
for c in C:
for y in Y:
classifier = Classifier(c=c, gamma=y, neighbours=neighbours)
classifier = Classifier(c=c, gamma=y, neighbours=neighbours, verbose=1)
classifier.train(learning_set)
result = classifier.test(test_set)
......
#!/usr/bin/python
from cPickle import load
from sys import argv, exit
from pylab import imsave
from pylab import imsave, plot, subplot, imshow, show, axis, title
from math import sqrt, ceil
import os
from Classifier import Classifier
......@@ -25,39 +26,56 @@ print 'Loading test set...'
test_set = load(file(test_set_file, 'r'))
l = len(test_set)
matches = 0
classified = {}
#classified = {}
classified = []
for i, char in enumerate(test_set):
prediction = classifier.classify(char, char.value)
if char.value != prediction:
key = '%s_as_%s' % (char.value, prediction)
classified.append((char, prediction))
if key not in classified:
classified[key] = [char]
else:
classified[key].append(char)
#key = '%s_as_%s' % (char.value, prediction)
#if key not in classified:
# classified[key] = [char]
#else:
# classified[key].append(char)
print '"%s" was classified as "%s"' \
% (char.value, prediction)
else:
matches += 1
print '%d of %d (%d%% done)' % (i + 1, l, int(100 * (i + 1) / l))
print '%d of %d (%d%% done)' % (i + 1, l, round(100 * (i + 1) / l))
print '\n%d matches (%d%%), %d fails' % (matches, \
int(100 * matches / l), \
round(100 * matches / l), \
len(test_set) - matches)
print 'Saving faulty classified characters...'
folder = '../images/faulty/'
# Show a grid plot of all faulty classified characters
print 'Plotting faulty classified characters...'
rows = int(ceil(sqrt(l - matches)))
columns = int(ceil((l - matches) / float(rows)))
if not os.path.exists(folder):
os.mkdir(folder)
for i, pair in enumerate(classified):
char, prediction = pair
subplot(rows, columns, i + 1)
title('%s as %s' % (char.value, prediction))
imshow(char.image.data, cmap='gray')
axis('off')
for filename, chars in classified.iteritems():
if len(chars) == 1:
imsave('%s%s' % (folder, filename), char.image.data, cmap='gray')
else:
for i, char in enumerate(chars):
imsave('%s%s_%d' % (folder, filename, i), char.image.data, cmap='gray')
show()
#print 'Saving faulty classified characters...'
#folder = '../images/faulty/'
#
#if not os.path.exists(folder):
# os.mkdir(folder)
#
#for filename, chars in classified.iteritems():
# if len(chars) == 1:
# imsave('%s%s' % (folder, filename), char.image.data, cmap='gray')
# else:
# for i, char in enumerate(chars):
# imsave('%s%s_%d' % (folder, filename, i), char.image.data, cmap='gray')
#!/usr/bin/python
from cPickle import load
from sys import argv, exit
from Classifier import Classifier
if len(argv) < 5:
print 'Usage: python %s FILE_SUFFIX C GAMMA NEIGHBOURS' % argv[0]
exit(1)
print 'Loading learning set'
learning_set = load(file('learning_set%s.dat' % argv[1], 'r'))
# Train the classifier with the learning set
classifier = Classifier(c=float(argv[1]), \
gamma=float(argv[2]), \
neighbours=int(argv[3]))
classifier.train(learning_set)
print 'Loading test set...'
test_set = load(file('test_set%s.dat' % argv[1], 'r'))
l = len(test_set)
matches = 0
for i, char in enumerate(test_set):
prediction = classifier.classify(char, char.value)
if char.value == prediction:
print ':-----> Successfully recognized "%s"' % char.value,
matches += 1
else:
print ':( Expected character "%s", got "%s"' \
% (char.value, prediction),
print ' -- %d of %d (%d%% done)' % (i + 1, l, int(100 * (i + 1) / l))
print '\n%d matches (%d%%), %d fails' % (matches, \
int(100 * matches / l), \
len(test_set) - matches)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment