Преглед изворни кода

Merge branch 'master' of kompiler.org:trs

Sander Mathijs van Veen пре 14 година
родитељ
комит
0c9c2995d5
3 измењених фајлова са 82 додато и 17 уклоњено
  1. 1 1
      external/pybison
  2. 36 13
      src/rules/factors.py
  3. 45 3
      tests/test_rules_factors.py

+ 1 - 1
external/pybison

@@ -1 +1 @@
-Subproject commit b4fd7ccf01d7030c3d6207c1ce2ff6bdbb8cad55
+Subproject commit 5f74eb1a7f356d9fcfe05a487e2ac2e1db0794b8

+ 36 - 13
src/rules/factors.py

@@ -1,31 +1,35 @@
+from itertools import product, combinations
+
+from .utils import nary_node
 from ..node import OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
-from .utils import nary_node
 
 
 def match_expand(node):
     """
     a * (b + c) -> ab + ac
+    (b + c) * a -> ab + ac
+    (a + b) * (c + d) -> ac + ad + bc + bd
     """
     assert node.is_op(OP_MUL)
 
-    # TODO: fix!
-    return []
+    scope = node.get_scope()
 
     p = []
-    a = []
-    bc = []
+    leaves = []
+    additions = []
 
     for n in node.get_scope():
         if n.is_leaf():
-            a.append(n)
+            leaves.append(n)
         elif n.op == OP_ADD:
-            bc.append(n)
+            additions.append(n)
+
+    for args in product(leaves, additions):
+        p.append(P(node, expand_single, args))
 
-    if a and bc:
-        for a_node in a:
-            for bc_node in bc:
-                p.append(P(node, expand_single, a_node, bc_node))
+    for args in combinations(additions, 2):
+        p.append(P(node, expand_double, args))
 
     return p
 
@@ -35,7 +39,8 @@ def expand_single(root, args):
     Combine a leaf (a) multiplied with an addition of two expressions
     (b + c) to an addition of two multiplications.
 
-    >>> a * (b + c) -> a * b + a * c
+    a * (b + c) -> ab + ac
+    (b + c) * a -> ab + ac
     """
     a, bc = args
     b, c = bc
@@ -44,7 +49,25 @@ def expand_single(root, args):
     # Replace 'a' with the new expression
     scope[scope.index(a)] = a * b + a * c
 
-    # Remove the old addition
+    # Remove the addition
     scope.remove(bc)
 
     return nary_node('*', scope)
+
+
+def expand_double(root, args):
+    """
+    Rewrite two multiplied additions to an addition of four multiplications.
+
+    (a + b) * (c + d) -> ac + ad + bc + bd
+    """
+    (a, b), (c, d) = ab, cd = args
+    scope = root.get_scope()
+
+    # Replace 'b + c' with the new expression
+    scope[scope.index(ab)] = a * c + a * d + b * c + b * d
+
+    # Remove the right addition
+    scope.remove(cd)
+
+    return nary_node('*', scope)

+ 45 - 3
tests/test_rules_factors.py

@@ -1,4 +1,4 @@
-from src.rules.factors import match_expand, expand_single
+from src.rules.factors import match_expand, expand_single, expand_double
 from src.possibilities import Possibility as P
 from tests.rulestestcase import RulesTestCase
 from tests.test_rules_poly import tree
@@ -7,7 +7,49 @@ from tests.test_rules_poly import tree
 class TestRulesFactors(RulesTestCase):
 
     def test_match_expand(self):
-        pass
+        a, bc, d = tree('a,b + c,d')
+        b, c = bc
+
+        root = a * bc
+        possibilities = match_expand(root)
+        self.assertEqualPos(possibilities,
+                [P(root, expand_single, (a, bc))])
+
+        root = bc * a
+        possibilities = match_expand(root)
+        self.assertEqualPos(possibilities,
+                [P(root, expand_single, (a, bc))])
+
+        root = a * d * bc
+        possibilities = match_expand(root)
+        self.assertEqualPos(possibilities,
+                [P(root, expand_single, (a, bc)),
+                 P(root, expand_single, (d, bc))])
+
+        ab, cd = root = (a + b) * (c + d)
+        possibilities = match_expand(root)
+        self.assertEqualPos(possibilities,
+                [P(root, expand_double, (ab, cd))])
 
     def test_expand_single(self):
-        pass
+        a, b, c, d = tree('a,b,c,d')
+        bc = b + c
+
+        root = a * bc
+        self.assertEqualNodes(expand_single(root, (a, bc)),
+                              a * b + a * c)
+
+        root = a * d * bc
+        self.assertEqualNodes(expand_single(root, (a, bc)),
+                              (a * b + a * c) * d)
+
+    def test_expand_double(self):
+        (a, b), (c, d) = ab, cd = tree('a + b,c + d')
+
+        root = ab * cd
+        self.assertEqualNodes(expand_double(root, (ab, cd)),
+                              a * c + a * d + b * c + b * d)
+
+        root = a * ab * b * cd * c
+        self.assertEqualNodes(expand_double(root, (ab, cd)),
+                              a * (a * c + a * d + b * c + b * d) * b * c)