Commit 0c9c2995 authored by Sander Mathijs van Veen's avatar Sander Mathijs van Veen

Merge branch 'master' of kompiler.org:trs

parents 150add57 2f75e399
pybison @ 5f74eb1a
Subproject commit b4fd7ccf01d7030c3d6207c1ce2ff6bdbb8cad55 Subproject commit 5f74eb1a7f356d9fcfe05a487e2ac2e1db0794b8
from itertools import product, combinations
from .utils import nary_node
from ..node import OP_ADD, OP_MUL from ..node import OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
def match_expand(node): def match_expand(node):
""" """
a * (b + c) -> ab + ac a * (b + c) -> ab + ac
(b + c) * a -> ab + ac
(a + b) * (c + d) -> ac + ad + bc + bd
""" """
assert node.is_op(OP_MUL) assert node.is_op(OP_MUL)
# TODO: fix! scope = node.get_scope()
return []
p = [] p = []
a = [] leaves = []
bc = [] additions = []
for n in node.get_scope(): for n in node.get_scope():
if n.is_leaf(): if n.is_leaf():
a.append(n) leaves.append(n)
elif n.op == OP_ADD: elif n.op == OP_ADD:
bc.append(n) additions.append(n)
for args in product(leaves, additions):
p.append(P(node, expand_single, args))
if a and bc: for args in combinations(additions, 2):
for a_node in a: p.append(P(node, expand_double, args))
for bc_node in bc:
p.append(P(node, expand_single, a_node, bc_node))
return p return p
...@@ -35,7 +39,8 @@ def expand_single(root, args): ...@@ -35,7 +39,8 @@ def expand_single(root, args):
Combine a leaf (a) multiplied with an addition of two expressions Combine a leaf (a) multiplied with an addition of two expressions
(b + c) to an addition of two multiplications. (b + c) to an addition of two multiplications.
>>> a * (b + c) -> a * b + a * c a * (b + c) -> ab + ac
(b + c) * a -> ab + ac
""" """
a, bc = args a, bc = args
b, c = bc b, c = bc
...@@ -44,7 +49,25 @@ def expand_single(root, args): ...@@ -44,7 +49,25 @@ def expand_single(root, args):
# Replace 'a' with the new expression # Replace 'a' with the new expression
scope[scope.index(a)] = a * b + a * c scope[scope.index(a)] = a * b + a * c
# Remove the old addition # Remove the addition
scope.remove(bc) scope.remove(bc)
return nary_node('*', scope) return nary_node('*', scope)
def expand_double(root, args):
"""
Rewrite two multiplied additions to an addition of four multiplications.
(a + b) * (c + d) -> ac + ad + bc + bd
"""
(a, b), (c, d) = ab, cd = args
scope = root.get_scope()
# Replace 'b + c' with the new expression
scope[scope.index(ab)] = a * c + a * d + b * c + b * d
# Remove the right addition
scope.remove(cd)
return nary_node('*', scope)
from src.rules.factors import match_expand, expand_single from src.rules.factors import match_expand, expand_single, expand_double
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase from tests.rulestestcase import RulesTestCase
from tests.test_rules_poly import tree from tests.test_rules_poly import tree
...@@ -7,7 +7,49 @@ from tests.test_rules_poly import tree ...@@ -7,7 +7,49 @@ from tests.test_rules_poly import tree
class TestRulesFactors(RulesTestCase): class TestRulesFactors(RulesTestCase):
def test_match_expand(self): def test_match_expand(self):
pass a, bc, d = tree('a,b + c,d')
b, c = bc
root = a * bc
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc))])
root = bc * a
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc))])
root = a * d * bc
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc)),
P(root, expand_single, (d, bc))])
ab, cd = root = (a + b) * (c + d)
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_double, (ab, cd))])
def test_expand_single(self): def test_expand_single(self):
pass a, b, c, d = tree('a,b,c,d')
bc = b + c
root = a * bc
self.assertEqualNodes(expand_single(root, (a, bc)),
a * b + a * c)
root = a * d * bc
self.assertEqualNodes(expand_single(root, (a, bc)),
(a * b + a * c) * d)
def test_expand_double(self):
(a, b), (c, d) = ab, cd = tree('a + b,c + d')
root = ab * cd
self.assertEqualNodes(expand_double(root, (ab, cd)),
a * c + a * d + b * c + b * d)
root = a * ab * b * cd * c
self.assertEqualNodes(expand_double(root, (ab, cd)),
a * (a * c + a * d + b * c + b * d) * b * c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment