run_classifier.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #!/usr/bin/python
  2. from sys import argv, exit
  3. from pylab import subplot, imshow, show, axis, title
  4. from math import sqrt, ceil
  5. from create_characters import load_test_set
  6. from create_classifier import load_classifier
  7. if len(argv) < 3:
  8. print 'Usage: python %s NEIGHBOURS BLUR_SCALE [ C GAMMA ]' % argv[0]
  9. exit(1)
  10. neighbours = int(argv[1])
  11. blur_scale = float(argv[2])
  12. # Load classifier
  13. if len(argv) > 4:
  14. c = float(argv[3])
  15. gamma = float(argv[4])
  16. classifier = load_classifier(neighbours, blur_scale, c=c, gamma=gamma, \
  17. verbose=1)
  18. else:
  19. classifier = load_classifier(neighbours, blur_scale, verbose=1)
  20. # Load test set
  21. test_set = load_test_set(neighbours, blur_scale, verbose=1)
  22. # Classify each character in the test set, remembering all faulty
  23. # classified characters
  24. l = len(test_set)
  25. matches = 0
  26. classified = []
  27. for i, char in enumerate(test_set):
  28. prediction = classifier.classify(char, char.value)
  29. if char.value != prediction:
  30. classified.append((char, prediction))
  31. print '"%s" was classified as "%s"' \
  32. % (char.value, prediction)
  33. else:
  34. matches += 1
  35. print '%d of %d (%d%% done)' % (i + 1, l, round(100 * (i + 1) / l))
  36. print '\n%d matches (%.1f%%), %d fails' % (matches, \
  37. 100.0 * matches / l, len(test_set) - matches)
  38. # Show a grid plot of all faulty classified characters
  39. print 'Plotting faulty classified characters...'
  40. rows = int(ceil(sqrt(l - matches)))
  41. columns = int(ceil((l - matches) / float(rows)))
  42. for i, pair in enumerate(classified):
  43. char, prediction = pair
  44. subplot(rows, columns, i + 1)
  45. title('%s as %s' % (char.value, prediction))
  46. imshow(char.image.data, cmap='gray')
  47. axis('off')
  48. show()