find_svm_params.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #!/usr/bin/python
  2. import os
  3. from sys import argv, exit
  4. from Classifier import Classifier
  5. from data import DATA_FOLDER, RESULTS_FOLDER
  6. from create_characters import load_learning_set, load_test_set
  7. if len(argv) < 3:
  8. print 'Usage: python %s NEIGHBOURS BLUR_SCALE' % argv[0]
  9. exit(1)
  10. neighbours = int(argv[1])
  11. blur_scale = float(argv[2])
  12. suffix = '_%s_%s' % (blur_scale, neighbours)
  13. if not os.path.exists(RESULTS_FOLDER):
  14. os.mkdir(RESULTS_FOLDER)
  15. classifier_file = DATA_FOLDER + 'classifier%s.dat' % suffix
  16. results_file = '%sresult%s.txt' % (RESULTS_FOLDER, suffix)
  17. # Load learning set and test set
  18. learning_set = load_learning_set(neighbours, blur_scale, verbose=1)
  19. test_set = load_test_set(neighbours, blur_scale, verbose=1)
  20. # Perform a grid-search to find the optimal values for C and gamma
  21. C = [float(2 ** p) for p in xrange(-5, 16, 2)]
  22. Y = [float(2 ** p) for p in xrange(-15, 4, 2)]
  23. results = []
  24. best = (0,)
  25. i = 0
  26. for c in C:
  27. for y in Y:
  28. classifier = Classifier(c=c, gamma=y, neighbours=neighbours, verbose=1)
  29. classifier.train(learning_set)
  30. result = classifier.test(test_set)
  31. if result > best[0]:
  32. best = (result, c, y, classifier)
  33. results.append(result)
  34. i += 1
  35. print '%d of %d, c = %f, gamma = %f, result = %d%%' \
  36. % (i, len(C) * len(Y), c, y, int(round(result * 100)))
  37. i = 0
  38. s = ' c\y'
  39. for y in Y:
  40. s += ' | %f' % y
  41. s += '\n'
  42. for c in C:
  43. s += ' %7s' % c
  44. for y in Y:
  45. s += ' | %8d' % int(round(results[i] * 100))
  46. i += 1
  47. s += '\n'
  48. s += '\nBest result: %.3f%% for C = %f and gamma = %f' \
  49. % ((best[0] * 100,) + best[1:3])
  50. print 'Saving results...'
  51. f = open(results_file, 'w+')
  52. f.write(s + '\n')
  53. f.close()
  54. print 'Saving best classifier...'
  55. best[3].save(classifier_file)
  56. print '\n' + s