test_node.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # This file is part of TRS (http://math.kompiler.org)
  2. #
  3. # TRS is free software: you can redistribute it and/or modify it under the
  4. # terms of the GNU Affero General Public License as published by the Free
  5. # Software Foundation, either version 3 of the License, or (at your option) any
  6. # later version.
  7. #
  8. # TRS is distributed in the hope that it will be useful, but WITHOUT ANY
  9. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  10. # A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  11. # details.
  12. #
  13. # You should have received a copy of the GNU Affero General Public License
  14. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
  15. from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
  16. nary_node, get_scope, OP_ADD, infinity, absolute, sin, cos, tan, log, \
  17. ln, der, integral, int_def, eq
  18. from tests.rulestestcase import RulesTestCase, tree
  19. class TestNode(RulesTestCase):
  20. def setUp(self):
  21. self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
  22. self.n, self.f = tree('a + b + cd,f')
  23. (self.a, self.b), self.cd = self.n
  24. self.c, self.d = self.cd
  25. self.scope = Scope(self.n)
  26. def test___lt__(self):
  27. self.assertTrue(L(1) < L(2))
  28. self.assertFalse(L(1) < L(1))
  29. self.assertFalse(L(2) < L(1))
  30. self.assertTrue(L(2) < N('+', L(1), L(2)))
  31. self.assertFalse(N('+', L(1), L(2)) < L(1))
  32. self.assertTrue(N('^', L('a'), L(2)) < N('^', L('a'), L(3)))
  33. self.assertTrue(N('^', L(2), L('a')) < N('^', L(3), L('a')))
  34. self.assertTrue(N('*', L(2), N('^', L('a'), L('b')))
  35. < N('*', L(3), N('^', L('a'), L('b'))))
  36. self.assertFalse(N('^', L('a'), L(3)) < N('^', L('a'), L(2)))
  37. def test_is_op(self):
  38. self.assertTrue(N('+', *self.l[:2]).is_op(OP_ADD))
  39. self.assertFalse(N('-', *self.l[:2]).is_op(OP_ADD))
  40. def test_is_leaf(self):
  41. self.assertTrue(L(2).is_leaf)
  42. self.assertFalse(N('+', *self.l[:2]).is_leaf)
  43. def test_is_power(self):
  44. self.assertTrue(N('^', *self.l[2:]).is_power())
  45. self.assertFalse(N('+', *self.l[2:]).is_power())
  46. def test_is_power_exponent(self):
  47. self.assertTrue(N('^', *self.l[2:]).is_power(5))
  48. self.assertFalse(N('^', *self.l[2:]).is_power(2))
  49. def test_is_nary(self):
  50. self.assertTrue(N('+', *self.l[:2]).is_nary())
  51. self.assertTrue(N('-', *self.l[:2]).is_nary())
  52. self.assertTrue(N('*', *self.l[:2]).is_nary())
  53. self.assertFalse(N('^', *self.l[:2]).is_nary())
  54. def test_is_identifier(self):
  55. self.assertTrue(L('a').is_identifier())
  56. self.assertFalse(L(1).is_identifier())
  57. def test_is_int(self):
  58. self.assertTrue(L(1).is_int())
  59. self.assertFalse(L(1.5).is_int())
  60. self.assertFalse(L('a').is_int())
  61. def test_is_float(self):
  62. self.assertTrue(L(1.5).is_float())
  63. self.assertFalse(L(1).is_float())
  64. self.assertFalse(L('a').is_float())
  65. def test_is_numeric(self):
  66. self.assertTrue(L(1).is_numeric())
  67. self.assertTrue(L(1.5).is_numeric())
  68. self.assertFalse(L('a').is_numeric())
  69. def test_extract_polynome_properties_identifier(self):
  70. self.assertEqual(L('a').extract_polynome_properties(),
  71. (L(1), L('a'), L(1)))
  72. def test_extract_polynome_properties_None(self):
  73. self.assertIsNone(N('+').extract_polynome_properties())
  74. def test_extract_polynome_properties_power(self):
  75. power = N('^', L('a'), L(2))
  76. self.assertEqual(power.extract_polynome_properties(),
  77. (L(1), L('a'), L(2)))
  78. def test_extract_polynome_properties_coefficient_exponent_int(self):
  79. times = N('*', L(3), N('^', L('a'), L(2)))
  80. self.assertEqual(times.extract_polynome_properties(),
  81. (L(3), L('a'), L(2)))
  82. def test_extract_polynome_properties_coefficient_exponent_id(self):
  83. times = N('*', L(3), N('^', L('a'), L('b')))
  84. self.assertEqual(times.extract_polynome_properties(),
  85. (L(3), L('a'), L('b')))
  86. def test_get_scope_binary(self):
  87. plus = N('+', *self.l[:2])
  88. self.assertEqual(get_scope(plus), self.l[:2])
  89. def test_get_scope_nested_left(self):
  90. plus = N('+', N('+', *self.l[:2]), self.l[2])
  91. self.assertEqual(get_scope(plus), self.l[:3])
  92. def test_get_scope_nested_right(self):
  93. plus = N('+', self.l[0], N('+', *self.l[1:3]))
  94. self.assertEqual(get_scope(plus), self.l[:3])
  95. def test_get_scope_nested_deep(self):
  96. plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
  97. self.assertEqual(get_scope(plus), self.l)
  98. def test_get_scope_negation(self):
  99. root, a, b, c, d = tree('ab * -cd, a, b, -c, d')
  100. self.assertEqual(get_scope(root), [a, b, c, d])
  101. def test_get_scope_index(self):
  102. self.assertEqual(self.scope.index(self.a), 0)
  103. self.assertEqual(self.scope.index(self.b), 1)
  104. self.assertEqual(self.scope.index(self.cd), 2)
  105. def test_equals_node_leaf(self):
  106. a, b = plus = tree('a + b')
  107. self.assertFalse(a.equals(plus))
  108. self.assertFalse(plus.equals(a))
  109. def test_equals_other_op(self):
  110. plus, mul = tree('a + b, a * b')
  111. self.assertFalse(plus.equals(mul))
  112. def test_equals_add(self):
  113. p0, p1, p2, p3 = tree('a + b,a + b,b + a, a + c')
  114. self.assertTrue(p0.equals(p1))
  115. self.assertTrue(p0.equals(p2))
  116. self.assertFalse(p0.equals(p3))
  117. self.assertFalse(p2.equals(p3))
  118. def test_equals_mul(self):
  119. m0, m1, m2, m3 = tree('a * b,a * b,b * a, a * c')
  120. self.assertTrue(m0.equals(m1))
  121. self.assertTrue(m0.equals(m2))
  122. self.assertFalse(m0.equals(m3))
  123. self.assertFalse(m2.equals(m3))
  124. def test_equals_nary(self):
  125. p0, p1, p2, p3, p4 = \
  126. tree('a + b + c,a + c + b,b + a + c,b + c + a,a + b + d')
  127. self.assertTrue(p0.equals(p1))
  128. self.assertTrue(p0.equals(p2))
  129. self.assertTrue(p0.equals(p3))
  130. self.assertTrue(p1.equals(p2))
  131. self.assertTrue(p1.equals(p3))
  132. self.assertTrue(p2.equals(p3))
  133. self.assertFalse(p2.equals(p4))
  134. def test_equals_nary_mary(self):
  135. m0, m1 = tree('ab,2ab')
  136. self.assertFalse(m0.equals(m1))
  137. def test_equals_div(self):
  138. d0, d1, d2 = tree('a / b,a / b,b / a')
  139. self.assertTrue(d0.equals(d1))
  140. self.assertFalse(d0.equals(d2))
  141. def test_equals_neg(self):
  142. a0, a1 = tree('-a,a')
  143. self.assertFalse(a0.equals(a1))
  144. a0, a1 = tree('-a,-a')
  145. self.assertTrue(a0.equals(a1))
  146. m0, m1 = tree('-5 * -3,-5 * 6')
  147. self.assertFalse(m0.equals(m1))
  148. def test_equals_ignore_negation(self):
  149. p0, p1 = tree('-(a + b), a + b')
  150. self.assertTrue(p0.equals(p1, ignore_negation=True))
  151. a0, a1 = tree('-a,a')
  152. self.assertTrue(a0.equals(a1, ignore_negation=True))
  153. def test_scope___init__(self):
  154. self.assertEqual(self.scope.node, self.n)
  155. self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
  156. def test_scope_remove_leaf(self):
  157. self.scope.remove(self.b)
  158. self.assertEqual(self.scope.nodes, [self.a, self.cd])
  159. def test_scope_remove_node(self):
  160. self.scope.remove(self.cd)
  161. self.assertEqual(self.scope.nodes, [self.a, self.b])
  162. def test_scope_remove_error(self):
  163. self.assertRaises(ValueError, self.scope.remove, self.f)
  164. def test_scope_replace(self):
  165. self.scope.replace(self.cd, self.f)
  166. self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
  167. def test_nary_node(self):
  168. a, b, c, d = tree('a,b,c,d')
  169. self.assertEqualNodes(nary_node('+', [a]), a)
  170. self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
  171. self.assertEqualNodes(nary_node('+', [a, b, c]),
  172. N('+', N('+', a, b), c))
  173. self.assertEqualNodes(nary_node('+', [a, b, c, d]),
  174. N('+', N('+', N('+', a, b), c), d))
  175. def test_scope_as_nary_node(self):
  176. self.assertEqualNodes(self.scope.as_nary_node(), self.n)
  177. def test_scope_as_nary_node_negated(self):
  178. n = tree('-(a + b)')
  179. self.assertEqualNodes(Scope(n).as_nary_node(), n)
  180. self.assertEqualNodes(Scope(-n).as_nary_node(), -n)
  181. def test_contains(self):
  182. a, ab, bc, ln0, ln1, ma = tree('a, ab, bc, ln(a) + 1, ln(b) + 1, -a')
  183. self.assertTrue(a.contains(a))
  184. self.assertTrue(ab.contains(a))
  185. self.assertFalse(bc.contains(a))
  186. self.assertTrue(ln0.contains(a))
  187. self.assertFalse(ln1.contains(a))
  188. self.assertTrue(ma.contains(a))
  189. def test_construct_function_derivative(self):
  190. self.assertEqual(str(tree("(x ^ 2)'")), "[x ^ 2]'")
  191. self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''")
  192. self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2')
  193. def test_construct_function_integral(self):
  194. self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx')
  195. self.assertEqual(str(tree('int x ^ 2 dx')), 'int x ^ 2 dx')
  196. self.assertEqual(str(tree('int x ^ 2 dy')), 'int x ^ 2 dy')
  197. self.assertEqual(str(tree('int x ^ 2 dy')), 'int x ^ 2 dy')
  198. self.assertEqual(str(tree('int x + 1')), 'int x dx + 1')
  199. self.assertEqual(str(tree('int_a^b x ^ 2')), 'int_a^b x ^ 2 dx')
  200. self.assertEqual(str(tree('int_(a-b)^(a+b) x ^ 2')),
  201. 'int_(a - b)^(a + b) x ^ 2 dx')
  202. def test_construct_function_int_def(self):
  203. self.assertEqual(str(tree('[x ^ 2]_a^b')), '[x ^ 2]_a^b')
  204. self.assertEqual(str(tree('[x ^ 2]_(a-b)^(a+b)')),
  205. '[x ^ 2]_(a - b)^(a + b)')
  206. def test_construct_function_absolute_child(self):
  207. self.assertEqual(str(tree('ln(|x|)')), 'ln(|x|)')
  208. self.assertEqual(str(tree('sin(|x|)')), 'sin(|x|)')
  209. def test_construct_logarithm(self):
  210. self.assertEqual(str(tree('log n')), 'log(n)')
  211. self.assertEqual(str(tree('log(n)')), 'log(n)')
  212. self.assertEqual(str(tree('ln n')), 'ln(n)')
  213. self.assertEqual(str(tree('ln(n)')), 'ln(n)')
  214. self.assertEqual(str(tree('log_2 n')), 'log_2(n)')
  215. self.assertEqual(str(tree('log_2(n)')), 'log_2(n)')
  216. self.assertEqual(str(tree('log_g n')), 'log_g(n)')
  217. self.assertEqual(str(tree('log_(g + h) n')), 'log_(g + h)(n)')
  218. def test_infinity(self):
  219. self.assertEqual(infinity(), tree('oo'))
  220. def test_absolute(self):
  221. self.assertEqual(absolute(tree('x^2')), tree('|x^2|'))
  222. def test_sin(self):
  223. self.assertEqual(sin(tree('x')), tree('sin(x)'))
  224. def test_cos(self):
  225. self.assertEqual(cos(tree('x')), tree('cos(x)'))
  226. def test_tan(self):
  227. self.assertEqual(tan(tree('x')), tree('tan(x)'))
  228. def test_log(self):
  229. x = tree('x')
  230. self.assertEqual(log(x, 'e'), tree('ln x'))
  231. self.assertEqual(log(x, 2), tree('log_2 x'))
  232. self.assertEqual(log(x), tree('log x'))
  233. self.assertEqual(log(x, 10), tree('log x'))
  234. def test_ln(self):
  235. self.assertEqual(ln(tree('x')), tree('ln x'))
  236. def test_der(self):
  237. x2, x, y = tree('x ^ 2, x, y')
  238. self.assertEqual(der(x2), tree('[x ^ 2]\''))
  239. self.assertEqual(der(x2, x), tree('d/dx x ^ 2'))
  240. self.assertEqual(der(x2, y), tree('d/dy x ^ 2'))
  241. def test_integral(self):
  242. x2, x, y, a, b = tree('x ^ 2, x, y, a, b')
  243. self.assertEqual(integral(x2, x), tree('int x^2 dx'))
  244. self.assertEqual(integral(x2, x, a, b), tree('int_a^b x^2 dx'))
  245. self.assertEqual(integral(x2, y, a, b), tree('int_a^b x^2 dy'))
  246. def test_int_def(self):
  247. x2, a, b, expect = tree('x ^ 2, a, b, [x ^ 2]_a^b')
  248. self.assertEqual(int_def(x2, a, b), expect)
  249. def test_eq(self):
  250. x, a, b, expect = tree('x, a, b, x + a = b')
  251. self.assertEqual(eq(x + a, b), expect)