rulestestcase.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import unittest
  2. import doctest
  3. from src.node import ExpressionNode
  4. from src.parser import Parser
  5. from src.validation import validate
  6. from tests.parser import ParserWrapper
  7. def tree(exp, **kwargs):
  8. return ParserWrapper(Parser, **kwargs).run([exp])
  9. def rewrite(exp, **kwargs):
  10. wrapper = ParserWrapper(Parser, **kwargs)
  11. wrapper.run([exp])
  12. return wrapper.parser.rewrite(check_implicit=False)
  13. class RulesTestCase(unittest.TestCase):
  14. def assertDoctests(self, module):
  15. self.assertEqual(doctest.testmod(m=module)[0], 0,
  16. 'There are failed doctests.')
  17. def assertEqualPos(self, possibilities, expected):
  18. self.assertEqual(len(possibilities), len(expected))
  19. for p, e in zip(possibilities, expected):
  20. self.assertEqual(p.root, e.root)
  21. if p.args == None: # pragma: nocover
  22. self.assertIsNone(e.args)
  23. elif e.args == None: # pragma: nocover
  24. self.assertIsNone(p.args)
  25. else:
  26. for pair in zip(p.args, e.args):
  27. self.assertEqual(*pair)
  28. self.assertEqual(p, e)
  29. def assertEqualNodes(self, a, b):
  30. if not isinstance(a, ExpressionNode):
  31. return self.assertEqual(a, b)
  32. self.assertIsInstance(b, ExpressionNode)
  33. self.assertEqual(a.op, b.op)
  34. for ca, cb in zip(a, b):
  35. self.assertEqualNodes(ca, cb)
  36. def assertRewrite(self, rewrite_chain):
  37. try:
  38. for i, exp in enumerate(rewrite_chain[:-1]):
  39. self.assertMultiLineEqual(str(rewrite(exp)),
  40. str(rewrite_chain[i + 1]))
  41. except AssertionError as e: # pragma: nocover
  42. msg = e.args[0]
  43. msg += '-' * 30 + '\n'
  44. msg += 'rewrite failed: "%s" -> "%s"\n' \
  45. % (str(exp), str(rewrite_chain[i + 1]))
  46. msg += 'rewrite chain: ---\n'
  47. chain = []
  48. for j, c in enumerate(rewrite_chain):
  49. if i == j:
  50. chain.append('%2d %s <-- error' % (j, str(c)))
  51. else:
  52. chain.append('%2d %s' % (j, str(c)))
  53. e.message = msg + '\n'.join(chain)
  54. e.args = (e.message,) + e.args[1:]
  55. raise
  56. def assertValidate(self, exp, result):
  57. self.assertTrue(validate(exp, result),
  58. 'Validation failed: %s !=> %s')