Przeglądaj źródła

Merged classifier test scripts.

Taddeus Kroes 14 lat temu
rodzic
commit
b91057fc8c
4 zmienione pliki z 45 dodaniami i 62 usunięciami
  1. 7 3
      src/Classifier.py
  2. 1 1
      src/find_svm_params.py
  3. 37 19
      src/run_classifier.py
  4. 0 39
      src/test_classifier.py

+ 7 - 3
src/Classifier.py

@@ -3,7 +3,8 @@ from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
 
 
 class Classifier:
-    def __init__(self, c=None, gamma=None, filename=None, neighbours=3):
+    def __init__(self, c=None, gamma=None, filename=None, neighbours=3, \
+            verbose=0):
         self.neighbours = neighbours
 
         if filename:
@@ -18,6 +19,8 @@ class Classifier:
             self.param.gamma = gamma  # Parameter for radial kernel
             self.model = None
 
+        self.verbose = verbose
+
     def save(self, filename):
         """Save the SVM model in the given filename."""
         svm_save_model(filename, self.model)
@@ -30,8 +33,9 @@ class Classifier:
         l = len(learning_set)
 
         for i, char in enumerate(learning_set):
-            print 'Found "%s"  --  %d of %d (%d%% done)' \
-                  % (char.value, i + 1, l, int(100 * (i + 1) / l))
+            if self.verbose:
+                print 'Found "%s"  --  %d of %d (%d%% done)' \
+                    % (char.value, i + 1, l, round(100 * (i + 1) / l))
             classes.append(float(ord(char.value)))
             #features.append(char.get_feature_vector())
             char.get_single_cell_feature_vector(self.neighbours)

+ 1 - 1
src/find_svm_params.py

@@ -86,7 +86,7 @@ i = 0
 
 for c in C:
     for y in Y:
-        classifier = Classifier(c=c, gamma=y, neighbours=neighbours)
+        classifier = Classifier(c=c, gamma=y, neighbours=neighbours, verbose=1)
         classifier.train(learning_set)
         result = classifier.test(test_set)
 

+ 37 - 19
src/run_classifier.py

@@ -1,7 +1,8 @@
 #!/usr/bin/python
 from cPickle import load
 from sys import argv, exit
-from pylab import imsave
+from pylab import imsave, plot, subplot, imshow, show, axis, title
+from math import sqrt, ceil
 import os
 
 from Classifier import Classifier
@@ -25,39 +26,56 @@ print 'Loading test set...'
 test_set = load(file(test_set_file, 'r'))
 l = len(test_set)
 matches = 0
-classified = {}
+#classified = {}
+classified = []
 
 for i, char in enumerate(test_set):
     prediction = classifier.classify(char, char.value)
 
     if char.value != prediction:
-        key = '%s_as_%s' % (char.value, prediction)
+        classified.append((char, prediction))
 
-        if key not in classified:
-            classified[key] = [char]
-        else:
-            classified[key].append(char)
+        #key = '%s_as_%s' % (char.value, prediction)
+
+        #if key not in classified:
+        #    classified[key] = [char]
+        #else:
+        #    classified[key].append(char)
 
         print '"%s" was classified as "%s"' \
                 % (char.value, prediction)
     else:
         matches += 1
 
-    print '%d of %d (%d%% done)' % (i + 1, l, int(100 * (i + 1) / l))
+    print '%d of %d (%d%% done)' % (i + 1, l, round(100 * (i + 1) / l))
 
 print '\n%d matches (%d%%), %d fails' % (matches, \
-        int(100 * matches / l), \
+        round(100 * matches / l), \
         len(test_set) - matches)
 
-print 'Saving faulty classified characters...'
-folder = '../images/faulty/'
+# Show a grid plot of all faulty classified characters
+print 'Plotting faulty classified characters...'
+rows = int(ceil(sqrt(l - matches)))
+columns = int(ceil((l - matches) / float(rows)))
 
-if not os.path.exists(folder):
-    os.mkdir(folder)
+for i, pair in enumerate(classified):
+    char, prediction = pair
+    subplot(rows, columns, i + 1)
+    title('%s as %s' % (char.value, prediction))
+    imshow(char.image.data, cmap='gray')
+    axis('off')
 
-for filename, chars in classified.iteritems():
-    if len(chars) == 1:
-        imsave('%s%s' % (folder, filename), char.image.data, cmap='gray')
-    else:
-        for i, char in enumerate(chars):
-            imsave('%s%s_%d' % (folder, filename, i), char.image.data, cmap='gray')
+show()
+
+#print 'Saving faulty classified characters...'
+#folder = '../images/faulty/'
+#
+#if not os.path.exists(folder):
+#    os.mkdir(folder)
+#
+#for filename, chars in classified.iteritems():
+#    if len(chars) == 1:
+#        imsave('%s%s' % (folder, filename), char.image.data, cmap='gray')
+#    else:
+#        for i, char in enumerate(chars):
+#            imsave('%s%s_%d' % (folder, filename, i), char.image.data, cmap='gray')

+ 0 - 39
src/test_classifier.py

@@ -1,39 +0,0 @@
-#!/usr/bin/python
-from cPickle import load
-from sys import argv, exit
-
-from Classifier import Classifier
-
-if len(argv) < 5:
-    print 'Usage: python %s FILE_SUFFIX C GAMMA NEIGHBOURS' % argv[0]
-    exit(1)
-
-print 'Loading learning set'
-learning_set = load(file('learning_set%s.dat' % argv[1], 'r'))
-
-# Train the classifier with the learning set
-classifier = Classifier(c=float(argv[1]), \
-                        gamma=float(argv[2]), \
-                        neighbours=int(argv[3]))
-classifier.train(learning_set)
-
-print 'Loading test set...'
-test_set = load(file('test_set%s.dat' % argv[1], 'r'))
-l = len(test_set)
-matches = 0
-
-for i, char in enumerate(test_set):
-    prediction = classifier.classify(char, char.value)
-
-    if char.value == prediction:
-        print ':-----> Successfully recognized "%s"' % char.value,
-        matches += 1
-    else:
-        print ':( Expected character "%s", got "%s"' \
-                % (char.value, prediction),
-
-    print '  --  %d of %d (%d%% done)' % (i + 1, l, int(100 * (i + 1) / l))
-
-print '\n%d matches (%d%%), %d fails' % (matches, \
-        int(100 * matches / l), \
-        len(test_set) - matches)