Explorar o código

Added script to find SVM parameters.

Taddeus Kroes %!s(int64=14) %!d(string=hai) anos
pai
achega
3158862e20
Modificáronse 2 ficheiros con 38 adicións e 2 borrados
  1. 17 2
      src/Classifier.py
  2. 21 0
      src/find_svm_params.py

+ 17 - 2
src/Classifier.py

@@ -3,14 +3,17 @@ from svmutil import svm_train, svm_problem, svm_parameter, svm_predict, \
 
 
 
 
 class Classifier:
 class Classifier:
-    def __init__(self, c=None, filename=None):
+    def __init__(self, c=None, gamma=None, filename=None):
         if filename:
         if filename:
             # If a filename is given, load a model from the given filename
             # If a filename is given, load a model from the given filename
             self.model = svm_load_model(filename)
             self.model = svm_load_model(filename)
+        elif c == None or gamma == None:
+            raise Exception('Please specify both C and gamma.')
         else:
         else:
             self.param = svm_parameter()
             self.param = svm_parameter()
             self.param.kernel_type = 2  # Radial kernel type
             self.param.kernel_type = 2  # Radial kernel type
-            self.param.C = c
+            self.param.C = c  # Soft margin
+            self.param.gamma = gamma  # Parameter for radial kernel
             self.model = None
             self.model = None
 
 
     def save(self, filename):
     def save(self, filename):
@@ -33,6 +36,18 @@ class Classifier:
         problem = svm_problem(classes, features)
         problem = svm_problem(classes, features)
         self.model = svm_train(problem, self.param)
         self.model = svm_train(problem, self.param)
 
 
+    def test(self, test_set):
+        """Test the classifier with the given test set and return the score."""
+        matches = 0
+
+        for char in test_set:
+            prediction = self.classify(char)
+
+            if char.value == prediction:
+                matches += 1
+
+        return float(matches) / len(test_set)
+
     def classify(self, character):
     def classify(self, character):
         """Classify a character object, return its value."""
         """Classify a character object, return its value."""
         predict = lambda x: svm_predict([0], [x], self.model)[0][0]
         predict = lambda x: svm_predict([0], [x], self.model)[0][0]

+ 21 - 0
src/find_svm_params.py

@@ -0,0 +1,21 @@
+C = [2 ** p for p in xrange(-5, 16, 2)]:
+Y = [2 ** p for p in xrange(-15, 4, 2)]
+best_result = 0
+best_classifier = None
+
+learning_set = load(file('learning_set', 'r'))
+test_set = load(file('test_set', 'r'))
+
+# Perform a grid-search on different combinations of soft margin and gamma
+for c in C:
+    for y in Y:
+        classifier = Classifier(c=c, gamma=y)
+        classifier.train(learning_set)
+        result = classifier.test(test_set)
+
+        if result > best_result:
+            best_classifier = classifier
+
+        print 'c = %f, gamma = %f, result = %d%%' % (c, y, int(result * 100))
+
+best_classifier.save('best_classifier')