# test_node.py

  1. import unittest
  2. from src.node import ExpressionNode as N, ExpressionLeaf as L
  3. class TestNode(unittest.TestCase):
  4. def setUp(self):
  5. self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
  6. def test_replace_node(self):
  7. inner = N('+', L(1), L(2))
  8. node = N('+', inner, L(3))
  9. replacement = N('-', L(4), L(5))
  10. inner.replace(replacement)
  11. self.assertEqual(str(node), '4 - 5 + 3')
  12. def test_replace_leaf(self):
  13. inner = N('+', L(1), L(2))
  14. node = N('+', inner, L(3))
  15. replacement = L(4)
  16. inner.replace(replacement)
  17. self.assertEqual(str(node), '4 + 3')
  18. def test_is_power_true(self):
  19. self.assertTrue(N('^', *self.l[:2]).is_power())
  20. self.assertFalse(N('+', *self.l[:2]).is_power())
  21. def test_is_nary(self):
  22. self.assertTrue(N('+', *self.l[:2]).is_nary())
  23. self.assertTrue(N('-', *self.l[:2]).is_nary())
  24. self.assertTrue(N('*', *self.l[:2]).is_nary())
  25. self.assertFalse(N('^', *self.l[:2]).is_nary())
  26. def test_is_identifier(self):
  27. self.assertTrue(L('a').is_identifier())
  28. self.assertFalse(L(1).is_identifier())
  29. def test_is_int(self):
  30. self.assertTrue(L(1).is_int())
  31. self.assertFalse(L(1.5).is_int())
  32. self.assertFalse(L('a').is_int())
  33. def test_is_float(self):
  34. self.assertTrue(L(1.5).is_float())
  35. self.assertFalse(L(1).is_float())
  36. self.assertFalse(L('a').is_float())
  37. def test_is_numeric(self):
  38. self.assertTrue(L(1).is_numeric())
  39. self.assertTrue(L(1.5).is_numeric())
  40. self.assertFalse(L('a').is_numeric())
  41. def test_get_order_identifier(self):
  42. self.assertEqual(L('a').get_order(), ('a', 1, 1))
  43. def test_get_order_None(self):
  44. self.assertIsNone(L(1).get_order())
  45. def test_get_order_power(self):
  46. power = N('^', L('a'), L(2))
  47. self.assertEqual(power.get_order(), ('a', 2, 1))
  48. def test_get_order_coefficient_exponent_int(self):
  49. times = N('*', L(3), N('^', L('a'), L(2)))
  50. self.assertEqual(times.get_order(), ('a', 2, 3))
  51. def test_get_order_coefficient_exponent_id(self):
  52. times = N('*', L(3), N('^', L('a'), L('b')))
  53. self.assertEqual(times.get_order(), ('a', 'b', 3))
  54. def test_get_scope_binary(self):
  55. plus = N('+', *self.l[:2])
  56. self.assertEqual(plus.get_scope(), self.l[:2])
  57. def test_get_scope_nested_left(self):
  58. plus = N('+', N('+', *self.l[:2]), self.l[2])
  59. self.assertEqual(plus.get_scope(), self.l[:3])
  60. def test_get_scope_nested_right(self):
  61. plus = N('+', self.l[0], N('+', *self.l[1:3]))
  62. self.assertEqual(plus.get_scope(), self.l[:3])
  63. def test_get_scope_nested_deep(self):
  64. plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
  65. self.assertEqual(plus.get_scope(), self.l)