Commit 6f628ddd authored by Taddeus Kroes's avatar Taddeus Kroes

Added group combination to rules.

parent 95372310
from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW
from .poly import match_combine_polynomes
from .groups import match_combine_groups
from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \
match_multiply_exponents, match_duplicate_exponent, \
......@@ -10,7 +11,8 @@ from .fractions import match_constant_division, match_add_constant_fractions, \
RULES = {
OP_ADD: [match_add_constant_fractions, match_combine_polynomes],
OP_ADD: [match_add_constant_fractions, match_combine_groups, \
match_combine_polynomes],
OP_MUL: [match_expand, match_add_exponents, \
match_expand_and_add_fractions],
OP_DIV: [match_subtract_exponents, match_divide_numerics, \
......
from itertools import combinations
from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
ExpressionLeaf as Leaf
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
def match_combine_groups(node):
"""
Match possible combinations of groups of expressions using non-strict
equivalence.
Examples:
a + a -> 2a
a + 2a -> 3a
ab + ab -> 2ab
ab + 2ab -> 3ab
ab + ba -> 2ab
"""
assert node.is_op(OP_ADD)
p = []
groups = []
for n in node.get_scope():
groups.append((1, n, n))
# Each number multiplication yields a group, multiple occurences of
# the same group can be replaced by a single one
if n.is_op(OP_MUL):
scope = n.get_scope()
l = len(scope)
for i, sub_node in enumerate(scope):
if sub_node.is_numeric():
others = [scope[j] for j in range(i) + range(i + 1, l)]
g = others[0] if len(others) == 1 else Node('*', *others)
groups.append((sub_node, g, n))
for g0, g1 in combinations(groups, 2):
if g0[1].equals(g1[1]):
p.append(P(node, combine_groups, g0 + g1))
return p
def combine_groups(root, args):
c0, g0, n0, c1, g1, n1 = args
scope = root.get_scope()
if not isinstance(c0, Leaf):
c0 = Leaf(c0)
# Replace the left node with the new expression
scope[scope.index(n0)] = (c0 + c1) * g0
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
from src.rules.groups import match_combine_groups
from src.rules.groups import match_combine_groups, combine_groups
from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase, tree
class TestRulesGroups(RulesTestCase):
def test_(self):
pass
def test_match_combine_groups_no_const(self):
a0, a1 = root = tree('a + a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, a0, a0, 1, a1, a1))])
def test_match_combine_groups_single_const(self):
a0, mul = root = tree('a + 2a')
l2, a1 = mul
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, a0, a0, l2, a1, mul))])
def test_match_combine_groups_two_const(self):
((l2, a0), b), (l3, a1) = (m0, b), m1 = root = tree('2a + b + 3a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (l2, a0, m0, l3, a1, m1))])
def test_match_combine_groups_n_const(self):
((l2, a0), (l3, a1)), (l4, a2) = (m0, m1), m2 = root = tree('2a+3a+4a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (l2, a0, m0, l3, a1, m1)),
P(root, combine_groups, (l2, a0, m0, l4, a2, m2)),
P(root, combine_groups, (l3, a1, m1, l4, a2, m2))])
def test_match_combine_groups_identifier_group_no_const(self):
ab0, ab1 = root = tree('ab + ab')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, ab0, ab0, 1, ab1, ab1))])
def test_match_combine_groups_identifier_group_single_const(self):
m0, m1 = root = tree('ab + 2ab')
(l2, a), b = m1
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, m0, m0, l2, a * b, m1))])
def test_match_combine_groups_identifier_group_unordered(self):
m0, m1 = root = tree('ab + ba')
b, a = m1
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, m0, m0, 1, b * a, m1))])
def test_combine_groups_simple(self):
root, l1 = tree('a + a,1')
a0, a1 = root
self.assertEqualNodes(combine_groups(root, (1, a0, a0, 1, a1, a1)),
(l1 + 1) * a0)
def test_combine_groups_nary(self):
root, l1 = tree('ab + b + ba,1')
abb, ba = root
ab, b = abb
self.assertEqualNodes(combine_groups(root, (1, ab, ab, 1, ba, ba)),
(l1 + 1) * ab + b)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment