rulestestcase.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # This file is part of TRS (http://math.kompiler.org)
  2. #
  3. # TRS is free software: you can redistribute it and/or modify it under the
  4. # terms of the GNU Affero General Public License as published by the Free
  5. # Software Foundation, either version 3 of the License, or (at your option) any
  6. # later version.
  7. #
  8. # TRS is distributed in the hope that it will be useful, but WITHOUT ANY
  9. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  10. # A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  11. # details.
  12. #
  13. # You should have received a copy of the GNU Affero General Public License
  14. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
  15. import unittest
  16. import doctest
  17. from src.node import ExpressionNode
  18. from src.parser import Parser
  19. from src.validation import validate
  20. from tests.parser import ParserWrapper
  21. def tree(exp, **kwargs):
  22. return ParserWrapper(Parser, **kwargs).run([exp])
  23. def rewrite(exp, **kwargs):
  24. wrapper = ParserWrapper(Parser, **kwargs)
  25. wrapper.run([exp])
  26. return wrapper.parser.rewrite(check_implicit=False)
  27. class RulesTestCase(unittest.TestCase):
  28. def assertDoctests(self, module):
  29. self.assertEqual(doctest.testmod(m=module)[0], 0,
  30. 'There are failed doctests.')
  31. def assertEqualPos(self, possibilities, expected):
  32. self.assertEqual(len(possibilities), len(expected))
  33. for p, e in zip(possibilities, expected):
  34. self.assertEqual(p.root, e.root)
  35. if p.args == None: # pragma: nocover
  36. self.assertIsNone(e.args)
  37. elif e.args == None: # pragma: nocover
  38. self.assertIsNone(p.args)
  39. else:
  40. for pair in zip(p.args, e.args):
  41. self.assertEqual(*pair)
  42. self.assertEqual(p, e)
  43. def assertEqualNodes(self, a, b):
  44. if not isinstance(a, ExpressionNode):
  45. return self.assertEqual(a, b)
  46. self.assertIsInstance(b, ExpressionNode)
  47. self.assertEqual(a.op, b.op)
  48. for ca, cb in zip(a, b):
  49. self.assertEqualNodes(ca, cb)
  50. def assertRewrite(self, rewrite_chain):
  51. try:
  52. for i, exp in enumerate(rewrite_chain[:-1]):
  53. self.assertMultiLineEqual(str(rewrite(exp)),
  54. str(rewrite_chain[i + 1]))
  55. except AssertionError as e: # pragma: nocover
  56. msg = e.args[0]
  57. msg += '-' * 30 + '\n'
  58. msg += 'rewrite failed: "%s" -> "%s"\n' \
  59. % (str(exp), str(rewrite_chain[i + 1]))
  60. msg += 'rewrite chain: ---\n'
  61. chain = []
  62. for j, c in enumerate(rewrite_chain):
  63. if i == j:
  64. chain.append('%2d %s <-- error' % (j, str(c)))
  65. else:
  66. chain.append('%2d %s' % (j, str(c)))
  67. e.message = msg + '\n'.join(chain)
  68. e.args = (e.message,) + e.args[1:]
  69. raise
  70. def assertValidate(self, exp, result):
  71. self.assertTrue(validate(exp, result),
  72. 'Validation failed: %s !=> %s')