Browse Source

Restructured rules and added match_expand.

Sander Mathijs van Veen 14 years ago
parent
commit
6c6d35a6b9
4 changed files with 85 additions and 51 deletions
  1. 3 3
      src/rules/__init__.py
  2. 77 43
      src/rules/poly.py
  3. 1 1
      tests/test_b1_ch08.py
  4. 4 4
      tests/test_rules.py

+ 3 - 3
src/rules/__init__.py

@@ -1,8 +1,8 @@
-from ..node import ExpressionNode as Node, OP_ADD
-from .poly import match_combine_factors#, match_combine_parentheses
+from ..node import ExpressionNode as Node, OP_ADD, OP_MUL
+from .poly import match_combine_factors, match_expand
 
 
 RULES = {
         OP_ADD: [match_combine_factors],
-        #OP_MUL: [match_combine_parentheses],
+        OP_MUL: [match_expand],
         }

+ 77 - 43
src/rules/poly.py

@@ -1,57 +1,91 @@
 from itertools import combinations
 
-from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR
+from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD
 from ..possibilities import Possibility as P
 
+
+def match_expand(node):
+    """
+    a * (b + c) -> ab + ac
+    """
+    if node.type != TYPE_OPERATOR or not node.op & OP_MUL:
+        return []
+
+    p = []
+
+    # 'a' parts
+    left = []
+
+    # '(b + c)' parts
+    right = []
+
+    for n in node.get_scope():
+        if node.type == TYPE_OPERATOR:
+            if n.op & OP_ADD:
+                right.append(n)
+        else:
+            left.append(n)
+
+    if len(left) and len(right):
+        for l in left:
+            for r in right:
+                p.append(P(node, expand_single, l, r))
+
+    return p
+
+def expand_single(root, args):
+    """
+    Combine a leaf (left) multiplied with an addition of two expressions
+    (right) to an addition of two multiplications.
+
+    >>> a * (b + c) -> a * b + a * c
+    """
+    left, right = args
+    others = list(set(root.get_scope()) - {left, right})
+
 def match_combine_factors(node):
     """
     n + exp + m -> exp + (n + m)
     k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
     """
-    if node.type != TYPE_OPERATOR:
+    if node.type != TYPE_OPERATOR or not node.op & OP_ADD:
         return []
 
     p = []
 
-    if node.is_nary():
-        # Collect all nodes that can be combined
-        # Numeric leaves
-        numerics = []
-
-        # Identifier leaves of all orders, tuple format is;
-        # (identifier, exponent, coefficient)
-        orders = []
-
-        # Nodes that cannot be combined
-        others = []
-
-        for n in node.get_scope():
-            if isinstance(n, Leaf):
-                if n.is_numeric():
-                    numerics.append(n)
-                elif n.is_identifier():
-                    orders.append((n.value, 1, 1))
-            else:
-                order = n.get_order()
-
-                if order:
-                    orders += order
-                else:
-                    others.append(n)
-
-        if len(numerics) > 1:
-            for num0, num1 in combinations(numerics, 2):
-                p.append(P(node, combine_numerics, (num0, num1, others)))
-
-        if len(orders) > 1:
-            for order0, order1 in combinations(orders, 2):
-                id0, exponent0, coeff0 = order0
-                id1, exponent1, coeff1 = order1
-
-                if id0 == id1 and exponent0 == exponent1:
-                    # Same identifier and exponent -> combine coefficients
-                    args = order0 + (coeff1,) + (others,)
-                    p.append(P(node, combine_orders, args))
+    # Collect all nodes that can be combined
+    # Numeric leaves
+    numerics = []
+
+    # Identifier leaves of all orders, tuple format is;
+    # (identifier, exponent, coefficient)
+    orders = []
+
+    for n in node.get_scope():
+        if node.type == TYPE_OPERATOR:
+            order = n.get_order()
+
+            if order:
+                orders += order
+        else:
+            if n.is_numeric():
+                numerics.append(n)
+            elif n.is_identifier():
+                orders.append((n.value, 1, 1))
+
+    if len(numerics) > 1:
+        for num0, num1 in combinations(numerics, 2):
+            p.append(P(node, combine_numerics, (num0, num1)))
+
+    if len(orders) > 1:
+        for order0, order1 in combinations(orders, 2):
+            id0, exponent0, coeff0 = order0
+            id1, exponent1, coeff1 = order1
+
+            if id0 == id1 and exponent0 == exponent1:
+                # Same identifier and exponent -> combine coefficients
+                args = order0 + (coeff1,)
+                p.append(P(node, combine_orders, args))
 
     return p
 
@@ -63,8 +97,8 @@ def combine_numerics(root, args):
     Example:
     >>> 3 + 4 -> 7
     """
-    numerics, others = args
-    value = sum([n.value for n in numerics])
+    others = list(set(root.get_scope()) - set(args))
+    value = sum([n.value for n in args])
 
     return nary_node('+', others + [Leaf(value)])
 

+ 1 - 1
tests/test_b1_ch8.py → tests/test_b1_ch08.py

@@ -5,7 +5,7 @@ from src.node import ExpressionNode as N, ExpressionLeaf as L
 from tests.parser import run_expressions
 
 
-class TestB1Ch8(unittest.TestCase):
+class TestB1Ch08(unittest.TestCase):
 
     def test_diagnostic_test(self):
         run_expressions(Parser, [

+ 4 - 4
tests/test_rules.py

@@ -23,15 +23,15 @@ class TestRules(unittest.TestCase):
         l0, l1 = L(1), L(2)
         plus = N('+', l0, l1)
         p = match_combine_factors(plus)
-        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, []))])
+        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
 
     def test_match_combine_factors_numeric_combinations(self):
         l0, l1, l2 = L(1), L(2), L(2)
         plus = N('+', N('+', l0, l1), l2)
         p = match_combine_factors(plus)
-        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, [])),
-                                P(plus, combine_numerics, (l0, l2, [])),
-                                P(plus, combine_numerics, (l1, l2, []))])
+        self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
+                                P(plus, combine_numerics, (l0, l2)),
+                                P(plus, combine_numerics, (l1, l2))])
 
     def assertEqualPos(self, possibilities, expected):
         for p, e in zip(possibilities, expected):