rulestestcase.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import unittest
  2. import doctest
  3. from src.node import ExpressionNode
  4. from src.parser import Parser
  5. from tests.parser import ParserWrapper
  6. def tree(exp, **kwargs):
  7. return ParserWrapper(Parser, **kwargs).run([exp])
  8. def rewrite(exp, **kwargs):
  9. return ParserWrapper(Parser, **kwargs).run([exp, '@'])
  10. class RulesTestCase(unittest.TestCase):
  11. def assertDoctests(self, module):
  12. self.assertEqual(doctest.testmod(m=module)[0], 0,
  13. 'There are failed doctests.')
  14. def assertEqualPos(self, possibilities, expected):
  15. self.assertEqual(len(possibilities), len(expected))
  16. for p, e in zip(possibilities, expected):
  17. self.assertEqual(p.root, e.root)
  18. if p.args == None: # pragma: nocover
  19. self.assertIsNone(e.args)
  20. elif e.args == None: # pragma: nocover
  21. self.assertIsNone(p.args)
  22. else:
  23. for pair in zip(p.args, e.args):
  24. self.assertEqual(*pair)
  25. self.assertEqual(p, e)
  26. def assertEqualNodes(self, a, b):
  27. if not isinstance(a, ExpressionNode):
  28. return self.assertEqual(a, b)
  29. self.assertIsInstance(b, ExpressionNode)
  30. self.assertEqual(a.op, b.op)
  31. for ca, cb in zip(a, b):
  32. self.assertEqualNodes(ca, cb)
  33. def assertRewrite(self, rewrite_chain):
  34. try:
  35. for i, exp in enumerate(rewrite_chain[:-1]):
  36. self.assertMultiLineEqual(str(rewrite(exp)),
  37. str(rewrite_chain[i + 1]))
  38. except AssertionError as e: # pragma: nocover
  39. msg = e.args[0]
  40. msg += '-' * 30 + '\n'
  41. msg += 'rewrite failed: "%s" -> "%s"\n' \
  42. % (str(exp), str(rewrite_chain[i + 1]))
  43. msg += 'rewrite chain: ---\n'
  44. chain = []
  45. for j, c in enumerate(rewrite_chain):
  46. if i == j:
  47. chain.append('%2d %s <-- error' % (j, str(c)))
  48. else:
  49. chain.append('%2d %s' % (j, str(c)))
  50. e.message = msg + '\n'.join(chain)
  51. e.args = (e.message,) + e.args[1:]
  52. raise