#!/usr/bin/python
"""find_svm_params.py -- grid-search the SVM parameters C and gamma.

Usage: python find_svm_params.py NEIGHBOURS BLUR_SCALE

The script builds (or loads from cache) the character objects, splits them
into a learning set and a test set, trains one Classifier for every
(C, gamma) pair of the grid and stores the score table together with the
best classifier.  All cache and output file names carry the suffix
``_<BLUR_SCALE>_<NEIGHBOURS>`` so runs with different settings do not
overwrite each other.

This is a Python 2 script (it imports cPickle).  The print calls are written
with a single parenthesised argument, which Python 2 treats as an ordinary
print statement.
"""
from collections import defaultdict
from os import listdir
from os.path import exists
from cPickle import load, dump
from sys import argv, exit

from GrayscaleImage import GrayscaleImage
from NormalizedCharacterImage import NormalizedCharacterImage
from Character import Character
from Classifier import Classifier

# One sub-directory per character value, each holding that character's images.
IMAGES_DIR = '../images/LearningSet'

# Height passed to NormalizedCharacterImage for every character image.
CHARACTER_HEIGHT = 42

# The first 70 samples of every character value are used for learning; any
# further samples of that value go to the test set.
LEARNING_SAMPLES_PER_VALUE = 70

# Search grid: C = 2^-5, 2^-3, ..., 2^15 and gamma = 2^-15, 2^-13, ..., 2^3.
C_VALUES = [float(2 ** p) for p in range(-5, 16, 2)]
GAMMA_VALUES = [float(2 ** p) for p in range(-15, 4, 2)]


def _load_pickle(path):
    """Return the object pickled in the file at `path`."""
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # this script wrote itself.
    # The text mode of the original script is kept so that caches written by
    # earlier runs remain readable.
    with open(path, 'r') as stream:
        return load(stream)


def _save_pickle(obj, path):
    """Pickle `obj` into the file at `path`, replacing any existing file."""
    with open(path, 'w') as stream:
        dump(obj, stream)


def load_characters(chars_file, neighbours, blur_scale):
    """Return the list of Character objects, generating it when not cached.

    If `chars_file` exists it is unpickled and returned.  Otherwise every
    image below IMAGES_DIR is read, normalized with `blur_scale`, wrapped in
    a Character labelled with the name of its directory, and the resulting
    list is pickled to `chars_file`.
    """
    if exists(chars_file):
        print('Loading characters...')
        return _load_pickle(chars_file)

    print('Going to generate character objects...')
    chars = []

    for value in sorted(listdir(IMAGES_DIR)):
        value_dir = IMAGES_DIR + '/' + value

        for image_name in sorted(listdir(value_dir)):
            image = GrayscaleImage(value_dir + '/' + image_name)
            norm = NormalizedCharacterImage(image, blur=blur_scale,
                                            height=CHARACTER_HEIGHT)
            character = Character(value, [], norm)
            # The return value is not used here; presumably the call stores
            # the feature vector on the character so that it is pickled along
            # with it -- TODO confirm against Character.
            character.get_single_cell_feature_vector(neighbours)
            chars.append(character)
            # Progress output: one line per processed image.
            print(value)

    print('Saving characters...')
    _save_pickle(chars, chars_file)

    return chars


def split_characters(chars):
    """Split `chars` into a (learning_set, test_set) pair of lists.

    For every distinct `value`, the first LEARNING_SAMPLES_PER_VALUE
    characters (in input order) go to the learning set and all later ones go
    to the test set.
    """
    learning_set = []
    test_set = []
    # Number of characters per value already placed in the learning set.
    # A counter replaces the repeated list.count() scan of the original;
    # this assumes the values are hashable (they are directory names when
    # generated by load_characters).
    learned = defaultdict(int)

    for char in chars:
        if learned[char.value] == LEARNING_SAMPLES_PER_VALUE:
            test_set.append(char)
        else:
            learning_set.append(char)
            learned[char.value] += 1

    return learning_set, test_set


def load_sets(chars, learning_set_file, test_set_file):
    """Return (learning_set, test_set), generating both when not cached.

    The cached sets are only used when BOTH files exist; if either one is
    missing the sets are regenerated from `chars` and both files are written
    again, instead of failing on the missing file.
    """
    if exists(learning_set_file) and exists(test_set_file):
        print('Loading learning set...')
        learning_set = _load_pickle(learning_set_file)
        print('Learning set: %s' % ([c.value for c in learning_set],))
        print('Loading test set...')
        test_set = _load_pickle(test_set_file)
        print('Test set: %s' % ([c.value for c in test_set],))
        return learning_set, test_set

    print('Going to generate learning set and test set...')
    learning_set, test_set = split_characters(chars)

    print('Learning set: %s' % ([c.value for c in learning_set],))
    print('\nTest set: %s' % ([c.value for c in test_set],))
    print('\nSaving learning set...')
    _save_pickle(learning_set, learning_set_file)
    print('Saving test set...')
    _save_pickle(test_set, test_set_file)

    return learning_set, test_set


def grid_search(learning_set, test_set, neighbours):
    """Train and test one classifier per (C, gamma) pair of the grid.

    Returns (results, best): `results` holds the score of every run in
    row-major order (C outer, gamma inner) and `best` is the tuple
    (score, c, gamma, classifier) of the highest-scoring run.  On equal
    scores the earliest run wins.
    """
    results = []
    # Seeded by the first run rather than with a (0,) placeholder, so `best`
    # always has all four fields even when no run scores above zero.
    best = None
    total = len(C_VALUES) * len(GAMMA_VALUES)

    for c in C_VALUES:
        for gamma in GAMMA_VALUES:
            classifier = Classifier(c=c, gamma=gamma, neighbours=neighbours)
            classifier.train(learning_set)
            result = classifier.test(test_set)

            if best is None or result > best[0]:
                best = (result, c, gamma, classifier)

            results.append(result)
            print('%d of %d, c = %f, gamma = %f, result = %d%%'
                  % (len(results), total, c, gamma, int(round(result * 100))))

    return results, best


def format_results(results, best):
    """Return the score table and the best-result line as one string.

    The table has one row per C value and one column per gamma value; each
    cell is the score of that run as a rounded percentage.
    """
    columns = len(GAMMA_VALUES)
    lines = [' c\\y' + ''.join('| %f' % gamma for gamma in GAMMA_VALUES)]

    for row, c in enumerate(C_VALUES):
        scores = results[row * columns:(row + 1) * columns]
        cells = ''.join('| %8d' % int(round(score * 100)) for score in scores)
        lines.append((' %7s' % c) + cells)

    # The score is scaled by 100 here as well: it is treated as a fraction
    # everywhere else, so printing it unscaled next to a '%' sign was wrong.
    summary = '\nBest result: %.3f%% for C = %f and gamma = %f' \
        % (best[0] * 100, best[1], best[2])

    return '\n'.join(lines) + '\n' + summary


def main():
    """Parse the command line, run the grid search and store its outcome."""
    if len(argv) < 3:
        print('Usage: python %s NEIGHBOURS BLUR_SCALE' % argv[0])
        exit(1)

    neighbours = int(argv[1])
    blur_scale = float(argv[2])
    suffix = '_%s_%s' % (blur_scale, neighbours)

    chars_file = 'characters%s.dat' % suffix
    learning_set_file = 'learning_set%s.dat' % suffix
    test_set_file = 'test_set%s.dat' % suffix
    classifier_file = 'classifier%s.dat' % suffix
    results_file = 'results%s.txt' % suffix

    # Load characters
    chars = load_characters(chars_file, neighbours, blur_scale)

    # Load learning set and test set
    learning_set, test_set = load_sets(chars, learning_set_file,
                                       test_set_file)

    # Perform a grid-search to find the optimal values for C and gamma
    results, best = grid_search(learning_set, test_set, neighbours)
    report = format_results(results, best)

    print('Saving results...')
    with open(results_file, 'w') as stream:
        stream.write(report + '\n')

    print('Saving best classifier...')
    best[3].save(classifier_file)

    print('\n' + report)


if __name__ == '__main__':
    main()