test_classifier.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #!/usr/bin/python
  2. from xml_helper_functions import xml_to_LicensePlate
  3. from Classifier import Classifier
  4. from cPickle import dump, load
  5. chars = load(file('characters.dat', 'r'))
  6. learning_set = []
  7. test_set = []
  8. #s = {}
  9. #
  10. #for char in chars:
  11. # if char.value not in s:
  12. # s[char.value] = [char]
  13. # else:
  14. # s[char.value].append(char)
  15. #
  16. #for value, chars in s.iteritems():
  17. # learning_set += chars[::2]
  18. # test_set += chars[1::2]
  19. learned = []
  20. for char in chars:
  21. if learned.count(char.value) == 70:
  22. test_set.append(char)
  23. else:
  24. learning_set.append(char)
  25. learned.append(char.value)
  26. print 'Learning set:', [c.value for c in learning_set]
  27. print 'Test set:', [c.value for c in test_set]
  28. print 'Saving learning set...'
  29. dump(learning_set, file('learning_set.dat', 'w+'))
  30. print 'Saving test set...'
  31. dump(test_set, file('test_set.dat', 'w+'))
  32. #----------------------------------------------------------------
  33. print 'Loading learning set'
  34. learning_set = load(file('learning_set.dat', 'r'))
  35. # Train the classifier with the learning set
  36. classifier = Classifier(c=512, gamma=.125, cell_size=12)
  37. classifier.train(learning_set)
  38. classifier.save('classifier.dat')
  39. print 'Saved classifier'
  40. #----------------------------------------------------------------
  41. print 'Loading classifier'
  42. classifier = Classifier(filename='classifier.dat')
  43. print 'Loading test set'
  44. test_set = load(file('test_set.dat', 'r'))
  45. l = len(test_set)
  46. matches = 0
  47. for i, char in enumerate(test_set):
  48. prediction = classifier.classify(char, char.value)
  49. if char.value == prediction:
  50. print ':-----> Successfully recognized "%s"' % char.value,
  51. matches += 1
  52. else:
  53. print ':( Expected character "%s", got "%s"' \
  54. % (char.value, prediction),
  55. print ' -- %d of %d (%d%% done)' % (i + 1, l, int(100 * (i + 1) / l))
  56. print '\n%d matches (%d%%), %d fails' % (matches, \
  57. int(100 * matches / len(test_set)), \
  58. len(test_set) - matches)