rulestestcase.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import unittest
  2. from src.node import ExpressionNode
  3. from src.parser import Parser
  4. from tests.parser import ParserWrapper
  5. def tree(exp, **kwargs):
  6. return ParserWrapper(Parser, **kwargs).run([exp])
  7. def rewrite(exp, **kwargs):
  8. return ParserWrapper(Parser, **kwargs).run([exp, '@'])
  9. class RulesTestCase(unittest.TestCase):
  10. def assertEqualPos(self, possibilities, expected):
  11. self.assertEqual(len(possibilities), len(expected))
  12. for p, e in zip(possibilities, expected):
  13. self.assertEqual(p.root, e.root)
  14. if p.args == None: # pragma: nocover
  15. self.assertIsNone(e.args)
  16. elif e.args == None: # pragma: nocover
  17. self.assertIsNone(p.args)
  18. else:
  19. for pair in zip(p.args, e.args):
  20. self.assertEqual(*pair)
  21. self.assertEqual(p, e)
  22. def assertEqualNodes(self, a, b):
  23. if not isinstance(a, ExpressionNode):
  24. return self.assertEqual(a, b)
  25. self.assertIsInstance(b, ExpressionNode)
  26. self.assertEqual(a.op, b.op)
  27. for ca, cb in zip(a, b):
  28. self.assertEqualNodes(ca, cb)
  29. def assertRewrite(self, rewrite_chain):
  30. try:
  31. for i, exp in enumerate(rewrite_chain[:-1]):
  32. self.assertMultiLineEqual(str(rewrite(exp)),
  33. str(rewrite_chain[i+1]))
  34. except AssertionError: # pragma: nocover
  35. print 'rewrite failed: "%s" -> "%s"' \
  36. % (str(exp), str(rewrite_chain[i+1]))
  37. print 'rewrite chain index: %d' % i
  38. print 'rewrite chain: ---'
  39. for i, c in enumerate(rewrite_chain):
  40. print '%2d %s' % (i, str(c))
  41. print '-' * 30
  42. raise