rulestestcase.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # This file is part of TRS (http://math.kompiler.org)
  2. #
  3. # TRS is free software: you can redistribute it and/or modify it under the
  4. # terms of the GNU Affero General Public License as published by the Free
  5. # Software Foundation, either version 3 of the License, or (at your option) any
  6. # later version.
  7. #
  8. # TRS is distributed in the hope that it will be useful, but WITHOUT ANY
  9. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  10. # A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  11. # details.
  12. #
  13. # You should have received a copy of the GNU Affero General Public License
  14. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
  15. import unittest
  16. import doctest
  17. from src.node import ExpressionNode
  18. from src.parser import Parser
  19. from src.validation import validate, VALIDATE_SUCCESS, \
  20. VALIDATE_FAILURE, VALIDATE_NOPROGRESS
  21. from tests.parser import ParserWrapper
  22. def tree(exp, **kwargs):
  23. return ParserWrapper(Parser, **kwargs).run([exp])
  24. def rewrite(exp, **kwargs):
  25. wrapper = ParserWrapper(Parser, **kwargs)
  26. wrapper.run([exp])
  27. return wrapper.parser.rewrite(check_implicit=False)
  28. class RulesTestCase(unittest.TestCase):
  29. def assertDoctests(self, module):
  30. self.assertEqual(doctest.testmod(m=module)[0], 0,
  31. 'There are failed doctests.')
  32. def assertEqualPos(self, possibilities, expected):
  33. self.assertEqual(len(possibilities), len(expected))
  34. for p, e in zip(possibilities, expected):
  35. self.assertEqual(p.root, e.root)
  36. if p.args == None: # pragma: nocover
  37. self.assertIsNone(e.args)
  38. elif e.args == None: # pragma: nocover
  39. self.assertIsNone(p.args)
  40. else:
  41. for pair in zip(p.args, e.args):
  42. self.assertEqual(*pair)
  43. self.assertEqual(p, e)
  44. def assertEqualNodes(self, a, b):
  45. if not isinstance(a, ExpressionNode):
  46. return self.assertEqual(a, b)
  47. self.assertIsInstance(b, ExpressionNode)
  48. self.assertEqual(a.op, b.op)
  49. for ca, cb in zip(a, b):
  50. self.assertEqualNodes(ca, cb)
  51. def assertRewrite(self, rewrite_chain):
  52. try:
  53. for i, exp in enumerate(rewrite_chain[:-1]):
  54. self.assertMultiLineEqual(str(rewrite(exp)),
  55. str(rewrite_chain[i + 1]))
  56. except AssertionError as e: # pragma: nocover
  57. msg = e.args[0]
  58. msg += '-' * 30 + '\n'
  59. msg += 'rewrite failed: "%s" -> "%s"\n' \
  60. % (str(exp), str(rewrite_chain[i + 1]))
  61. msg += 'rewrite chain: ---\n'
  62. chain = []
  63. for j, c in enumerate(rewrite_chain):
  64. if i == j:
  65. chain.append('%2d %s <-- error' % (j, str(c)))
  66. else:
  67. chain.append('%2d %s' % (j, str(c)))
  68. e.message = msg + '\n'.join(chain)
  69. e.args = (e.message,) + e.args[1:]
  70. raise
  71. def assertEvaluates(self, exp, result):
  72. node = tree(exp)
  73. while node:
  74. s = str(node)
  75. if s == result:
  76. return True
  77. node = rewrite(s)
  78. raise AssertionError('`%s` does not rewrite to `%s`' % (exp, result))
  79. def assertValidateSuccess(self, a, b):
  80. self.assertEqual(validate(a, b), VALIDATE_SUCCESS,
  81. 'Validation failed: %s !=> %s' % (a, b))
  82. def assertValidateFailure(self, a, b):
  83. self.assertEqual(validate(a, b), VALIDATE_FAILURE,
  84. 'Validation dit not fail: %s => %s' % (a, b))
  85. def assertValidateNoprogress(self, a, b):
  86. self.assertEqual(validate(a, b), VALIDATE_NOPROGRESS, 'Validation '
  87. 'did detect progress or failed for %s => %s' % (a, b))