Sfoglia il codice sorgente

Cleaned up test scripts.

Taddeus Kroes 14 anni fa
parent
commit
92b7ce5979
2 ha cambiato i file con 11 aggiunte e 14 eliminazioni
  1. 7 10
      src/find_svm_params.py
  2. 4 4
      src/test_classifier.py

+ 7 - 10
src/find_svm_params.py

@@ -2,10 +2,8 @@
 from cPickle import load
 from Classifier import Classifier
 
-#C = [float(2 ** p) for p in xrange(-5, 16, 2)]
-#Y = [float(2 ** p) for p in xrange(-15, 4, 2)]
-C = [float(2 ** p) for p in xrange(1, 16, 2)]
-Y = [float(2 ** p) for p in xrange(-13, 4, 2)]
+C = [float(2 ** p) for p in xrange(-5, 16, 2)]
+Y = [float(2 ** p) for p in xrange(-15, 4, 2)]
 best_classifier = None
 
 print 'Loading learning set...'
@@ -17,7 +15,7 @@ print 'Test set:', [c.value for c in test_set]
 
 # Perform a grid-search on different combinations of soft margin and gamma
 results = []
-maximum = (0, 0, 0)
+best = (0,)
 i = 0
 
 for c in C:
@@ -26,9 +24,8 @@ for c in C:
         classifier.train(learning_set)
         result = classifier.test(test_set)
 
-        if result > maximum[2]:
-            maximum = (c, y, result)
-            best_classifier = classifier
+        if result > best[0]:
+            best = (result, c, y, classifier)
 
         results.append(result)
         i += 1
@@ -52,6 +49,6 @@ for c in C:
 
     print
 
-print '\nmax:', maximum
+print '\nBest result: %.3f%% for C = %f and gamma = %f' % best[:3]
 
-best_classifier.save('best_classifier.dat')
+best[3].save('classifier.dat')

+ 4 - 4
src/test_classifier.py

@@ -41,11 +41,11 @@ learning_set = load(file('learning_set.dat', 'r'))
 # Train the classifier with the learning set
 classifier = Classifier(c=512, gamma=.125, cell_size=12)
 classifier.train(learning_set)
-#classifier.save('classifier')
-#print 'Saved classifier'
+classifier.save('classifier.dat')
+print 'Saved classifier'
 #----------------------------------------------------------------
-#print 'Loading classifier'
-#classifier = Classifier(filename='classifier')
+print 'Loading classifier'
+classifier = Classifier(filename='classifier.dat')
 print 'Loading test set'
 test_set = load(file('test_set.dat', 'r'))
 l = len(test_set)