Procházet zdrojové kódy

Added tests for polynome combinations.

Sander Mathijs van Veen před 14 roky
rodič
revize
7615831379
2 změnil soubory, kde provedl 80 přidání a 28 odebrání
  1. 17 8
      src/rules/poly.py
  2. 63 20
      tests/test_rules_poly.py

+ 17 - 8
src/rules/poly.py

@@ -90,23 +90,24 @@ def match_combine_polynomes(node, verbose=False):
     # Each combination of powers of the same value and polynome can be added
     if len(polys) >= 2:
         for left, right in combinations(polys, 2):
-            c0, r0, e0 = left[1]
-            c1, r1, e1 = right[1]
+            n0, p0 = left
+            n1, p1 = right
+            c0, r0, e0 = p0
+            c1, r1, e1 = p1
 
             # Both numeric root and same exponent -> combine coefficients and
             # roots, or: same root and exponent -> combine coefficients.
+            # TODO: Addition with zero, e.g. a + 0 -> a
             if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \
                     and r0.is_numeric() and r1.is_numeric():
                 # 2 + 3 -> 5
-                p.append(P(node, combine_numerics, \
-                           (left[0], right[0], r0, r1)))
+                p.append(P(node, combine_numerics, (n0, n1, r0.value, r1.value)))
             elif c0.is_numeric() and c1.is_numeric() and r0 == r1 and e0 == e1:
                 # 2a + 2a -> 4a
                 # a + 2a -> 3a
                 # 2a + a -> 3a
                 # a + a -> 2a
-                p.append(P(node, combine_polynomes, \
-                           (left[0], right[0], c0, c1, r0, e0)))
+                p.append(P(node, combine_polynomes, (n0, n1, c0, c1, r0, e0)))
 
     return p
 
@@ -118,9 +119,17 @@ def combine_numerics(root, args):
     Synopsis:
     c0 + c1 -> eval(c1 + c2)
     """
-    c0, c1 = args
+    n0, n1, c0, c1 = args
 
-    return Leaf(c0.value + c1.value)
+    scope = root.get_scope()
+
+    # Replace the left node with the new expression
+    scope[scope.index(n0)] = Leaf(c0 + c1)
+
+    # Remove the right node
+    scope.remove(n1)
+
+    return nary_node('+', scope)
 
 
 def combine_polynomes(root, args):

+ 63 - 20
tests/test_rules_poly.py

@@ -3,6 +3,7 @@ import unittest
 from src.rules.poly import match_combine_polynomes, combine_polynomes, \
         combine_numerics
 from src.possibilities import Possibility as P
+from src.node import ExpressionNode, ExpressionLeaf as L
 from src.parser import Parser
 from tests.parser import ParserWrapper
 
@@ -24,15 +25,6 @@ class TestRulesPoly(unittest.TestCase):
 
             self.assertEqual(p, e)
 
-    def test_numbers(self):
-        return
-        # TODO: Move to combine numeric test
-        l1, l2 = root = tree('1+2')
-        possibilities = match_combine_polynomes(root)
-        self.assertEqualPos(possibilities,
-                [P(root, combine_numerics, ((l1, (l1, l1, l1)),
-                                             (l2, (l1, l2, l1))))])
-
     def test_identifiers_basic(self):
         a1, a2 = root = tree('a+a')
         possibilities = match_combine_polynomes(root)
@@ -96,16 +88,67 @@ class TestRulesPoly(unittest.TestCase):
         #self.assertEqualPos(possibilities,
         #        [P(root, combine_polynomes, (left, right, c, c, a_b, d))])
 
-    def test_match_combine_polynomes_numeric_combinations(self):
-        return
-        root = tree('0+1+2')
-        # TODO: this test fails with this code: l0, l1, l2 = tree('0,1,2')
-        l0, l1, l2 = root[0][0], root[0][1], root[1]
+    def test_match_combine_numerics(self):
+        l0, l1, l2 = tree('0,1,2')
+        root = l0 + l1 + l2
+
+        possibilities = match_combine_polynomes(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_numerics, (l0, l1, l0, l1)),
+                 P(root, combine_numerics, (l0, l2, l0, l2)),
+                 P(root, combine_numerics, (l1, l2, l1, l2))])
+
+    def test_match_combine_numerics_explicit_powers(self):
+        l0, l1, l2 = tree('0^1,1*1,1*2^1')
+        root = l0 + l1 + l2
+
         possibilities = match_combine_polynomes(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_polynomes, ((l0, (l1, l0, l1)),
-                                             (l1, (l1, l1, l1)))),
-                 P(root, combine_polynomes, ((l0, (l1, l0, l1)),
-                                             (l2, (l1, l2, l1)))),
-                 P(root, combine_polynomes, ((l1, (l1, l1, l1)),
-                                             (l2, (l1, l2, l1))))])
+                [P(root, combine_numerics, (l0, l1, l0[0], l1[1])),
+                 P(root, combine_numerics, (l0, l2, l0[0], l2[1][0])),
+                 P(root, combine_numerics, (l1, l2, l1[1], l2[1][0]))])
+
+    def test_combine_numerics(self):
+        l0, l1 = tree('1,2')
+        self.assertEqual(combine_numerics(l0 + l1, (l0, l1, 1, 2)), 3)
+
+    def test_combine_numerics_nary(self):
+        l0, a, l1 = tree('1,a,2')
+        self.assertEqual(combine_numerics(l0 + a + l1, (l0, l1, 1, 2)),
+                         L(3) + a)
+
+    def test_combine_polynomes(self):
+        # 2a + 3a -> (2 + 3) * a
+        l0, a, l1, l2 = tree('2,a,3,1')
+        root = l0 * a + l1 * a
+        left, right = root
+        replacement = combine_polynomes(root, (left, right, l0, l1, a, 1))
+        self.assertEqualNodes(replacement, (l0 + l1) * a)
+
+        # a + 3a -> (1 + 3) * a
+        root = a + l1 * a
+        left, right = root
+        replacement = combine_polynomes(root, (left, right, l2, l1, a, 1))
+        self.assertEqualNodes(replacement, (l2 + l1) * a)
+
+        # 2a + a -> (2 + 1) * a
+        root = l0 * a + a
+        left, right = root
+        replacement = combine_polynomes(root, (left, right, l0, l2, a, 1))
+        self.assertEqualNodes(replacement, (l0 + 1) * a)
+
+        # a + a -> (1 + 1) * a
+        root = a + a
+        left, right = root
+        replacement = combine_polynomes(root, (left, right, l2, l2, a, 1))
+        self.assertEqualNodes(replacement, (l2 + 1) * a)
+
+    def assertEqualNodes(self, a, b):
+        if not isinstance(a, ExpressionNode):
+            return self.assertEqual(a, b)
+
+        self.assertIsInstance(b, ExpressionNode)
+        self.assertEqual(a.op, b.op)
+
+        for ca, cb in zip(a, b):
+            self.assertEqualNodes(ca, cb)