Преглед изворни кода

Added numeric rules in separate file.

- Moved combine_numerics to new rules file and renamed it to add_numerics.
- Added (match_)divide_numerics that checks if numbers can be divided without
  losing precision.
Taddeus Kroes пре 14 година
родитељ
комит
6843d12e84
2 измењених фајлова са 15 додато и 45 уклоњено
  1. 5 26
      src/rules/poly.py
  2. 10 19
      tests/test_rules_poly.py

+ 5 - 26
src/rules/poly.py

@@ -1,17 +1,17 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, \
-        TYPE_OPERATOR, OP_ADD, OP_MUL
+        OP_ADD, OP_MUL
 from ..possibilities import Possibility as P
 from .utils import nary_node
+from .numerics import add_numerics
 
 
 def match_expand(node):
     """
     a * (b + c) -> ab + ac
     """
-    assert node.type == TYPE_OPERATOR
-    assert node.op == OP_MUL
+    assert node.is_op(OP_MUL)
 
     # TODO: fix!
     return []
@@ -60,8 +60,7 @@ def match_combine_polynomes(node, verbose=False):
     n + exp + m -> exp + (n + m)
     k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
     """
-    assert node.type == TYPE_OPERATOR
-    assert node.op == OP_ADD
+    assert node.is_op(OP_ADD)
 
     p = []
 
@@ -101,7 +100,7 @@ def match_combine_polynomes(node, verbose=False):
             if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \
                     and r0.is_numeric() and r1.is_numeric():
                 # 2 + 3 -> 5
-                p.append(P(node, combine_numerics, (n0, n1, r0.value, r1.value)))
+                p.append(P(node, add_numerics, (n0, n1, r0.value, r1.value)))
             elif c0.is_numeric() and c1.is_numeric() and r0 == r1 and e0 == e1:
                 # 2a + 2a -> 4a
                 # a + 2a -> 3a
@@ -112,26 +111,6 @@ def match_combine_polynomes(node, verbose=False):
     return p
 
 
-def combine_numerics(root, args):
-    """
-    Combine two constants to a single constant in an n-ary plus.
-
-    Synopsis:
-    c0 + c1 -> eval(c1 + c2)
-    """
-    n0, n1, c0, c1 = args
-
-    scope = root.get_scope()
-
-    # Replace the left node with the new expression
-    scope[scope.index(n0)] = Leaf(c0 + c1)
-
-    # Remove the right node
-    scope.remove(n1)
-
-    return nary_node('+', scope)
-
-
 def combine_polynomes(root, args):
     """
     Combine two multiplications of any polynome in an n-ary plus.

+ 10 - 19
tests/test_rules_poly.py

@@ -1,5 +1,5 @@
-from src.rules.poly import match_combine_polynomes, combine_polynomes, \
-        combine_numerics
+from src.rules.poly import match_combine_polynomes, combine_polynomes
+from src.rules.numerics import add_numerics
 from src.possibilities import Possibility as P
 from src.node import ExpressionLeaf as L
 from src.parser import Parser
@@ -76,34 +76,25 @@ class TestRulesPoly(RulesTestCase):
         #self.assertEqualPos(possibilities,
         #        [P(root, combine_polynomes, (left, right, c, c, a_b, d))])
 
-    def test_match_combine_numerics(self):
+    def test_match_add_numerics(self):
         l0, l1, l2 = tree('0,1,2')
         root = l0 + l1 + l2
 
         possibilities = match_combine_polynomes(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_numerics, (l0, l1, l0, l1)),
-                 P(root, combine_numerics, (l0, l2, l0, l2)),
-                 P(root, combine_numerics, (l1, l2, l1, l2))])
+                [P(root, add_numerics, (l0, l1, l0, l1)),
+                 P(root, add_numerics, (l0, l2, l0, l2)),
+                 P(root, add_numerics, (l1, l2, l1, l2))])
 
-    def test_match_combine_numerics_explicit_powers(self):
+    def test_match_add_numerics_explicit_powers(self):
         l0, l1, l2 = tree('0^1,1*1,1*2^1')
         root = l0 + l1 + l2
 
         possibilities = match_combine_polynomes(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_numerics, (l0, l1, l0[0], l1[1])),
-                 P(root, combine_numerics, (l0, l2, l0[0], l2[1][0])),
-                 P(root, combine_numerics, (l1, l2, l1[1], l2[1][0]))])
-
-    def test_combine_numerics(self):
-        l0, l1 = tree('1,2')
-        self.assertEqual(combine_numerics(l0 + l1, (l0, l1, 1, 2)), 3)
-
-    def test_combine_numerics_nary(self):
-        l0, a, l1 = tree('1,a,2')
-        self.assertEqual(combine_numerics(l0 + a + l1, (l0, l1, 1, 2)),
-                         L(3) + a)
+                [P(root, add_numerics, (l0, l1, l0[0], l1[1])),
+                 P(root, add_numerics, (l0, l2, l0[0], l2[1][0])),
+                 P(root, add_numerics, (l1, l2, l1[1], l2[1][0]))])
 
     def test_combine_polynomes(self):
         # 2a + 3a -> (2 + 3) * a