Commit 2f75e399 authored by Taddeus Kroes's avatar Taddeus Kroes

Implemented and tested factor decomposition (expand).

parent 4dd2dad4
from itertools import product, combinations
from .utils import nary_node
from ..node import OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
def match_expand(node):
"""
a * (b + c) -> ab + ac
(b + c) * a -> ab + ac
(a + b) * (c + d) -> ac + ad + bc + bd
"""
assert node.is_op(OP_MUL)
# TODO: fix!
return []
scope = node.get_scope()
p = []
a = []
bc = []
leaves = []
additions = []
for n in node.get_scope():
if n.is_leaf():
a.append(n)
leaves.append(n)
elif n.op == OP_ADD:
bc.append(n)
additions.append(n)
for args in product(leaves, additions):
p.append(P(node, expand_single, args))
if a and bc:
for a_node in a:
for bc_node in bc:
p.append(P(node, expand_single, a_node, bc_node))
for args in combinations(additions, 2):
p.append(P(node, expand_double, args))
return p
......@@ -35,7 +39,8 @@ def expand_single(root, args):
Combine a leaf (a) multiplied with an addition of two expressions
(b + c) to an addition of two multiplications.
>>> a * (b + c) -> a * b + a * c
a * (b + c) -> ab + ac
(b + c) * a -> ab + ac
"""
a, bc = args
b, c = bc
......@@ -44,7 +49,25 @@ def expand_single(root, args):
# Replace 'a' with the new expression
scope[scope.index(a)] = a * b + a * c
# Remove the old addition
# Remove the addition
scope.remove(bc)
return nary_node('*', scope)
def expand_double(root, args):
"""
Rewrite two multiplied additions to an addition of four multiplications.
(a + b) * (c + d) -> ac + ad + bc + bd
"""
(a, b), (c, d) = ab, cd = args
scope = root.get_scope()
# Replace 'b + c' with the new expression
scope[scope.index(ab)] = a * c + a * d + b * c + b * d
# Remove the right addition
scope.remove(cd)
return nary_node('*', scope)
from src.rules.factors import match_expand, expand_single
from src.rules.factors import match_expand, expand_single, expand_double
from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase
from tests.test_rules_poly import tree
......@@ -7,7 +7,49 @@ from tests.test_rules_poly import tree
class TestRulesFactors(RulesTestCase):
def test_match_expand(self):
pass
a, bc, d = tree('a,b + c,d')
b, c = bc
root = a * bc
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc))])
root = bc * a
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc))])
root = a * d * bc
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_single, (a, bc)),
P(root, expand_single, (d, bc))])
ab, cd = root = (a + b) * (c + d)
possibilities = match_expand(root)
self.assertEqualPos(possibilities,
[P(root, expand_double, (ab, cd))])
def test_expand_single(self):
pass
a, b, c, d = tree('a,b,c,d')
bc = b + c
root = a * bc
self.assertEqualNodes(expand_single(root, (a, bc)),
a * b + a * c)
root = a * d * bc
self.assertEqualNodes(expand_single(root, (a, bc)),
(a * b + a * c) * d)
def test_expand_double(self):
(a, b), (c, d) = ab, cd = tree('a + b,c + d')
root = ab * cd
self.assertEqualNodes(expand_double(root, (ab, cd)),
a * c + a * d + b * c + b * d)
root = a * ab * b * cd * c
self.assertEqualNodes(expand_double(root, (ab, cd)),
a * (a * c + a * d + b * c + b * d) * b * c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment