Commit 6e8d6345 authored by Taddeus Kroes's avatar Taddeus Kroes

Group combinations now support negated constants.

parent 8f90dd67
...@@ -21,6 +21,10 @@ def match_expand(node): ...@@ -21,6 +21,10 @@ def match_expand(node):
if n.is_leaf: if n.is_leaf:
leaves.append(n) leaves.append(n)
elif n.op == OP_ADD: elif n.op == OP_ADD:
# If the addition only contains numerics, do not expand
if not filter(lambda n: not n.is_numeric(), Scope(n)):
continue
additions.append(n) additions.append(n)
for args in product(leaves, additions): for args in product(leaves, additions):
......
from itertools import combinations from itertools import combinations
from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \ from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
OP_ADD, OP_MUL OP_ADD, OP_MUL, nary_node, negate
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -22,31 +22,37 @@ def match_combine_groups(node): ...@@ -22,31 +22,37 @@ def match_combine_groups(node):
p = [] p = []
groups = [] groups = []
root_scope = Scope(node) scope = Scope(node)
for n in root_scope: for n in scope:
groups.append((1, n, n)) if not n.is_numeric():
groups.append((Leaf(1), n, n))
# Each number multiplication yields a group, multiple occurences of # Each number multiplication yields a group, multiple occurences of
# the same group can be replaced by a single one # the same group can be replaced by a single one
if n.is_op(OP_MUL): if n.is_op(OP_MUL):
scope = Scope(n) n_scope = Scope(n)
l = len(scope) l = len(n_scope)
for i, sub_node in enumerate(scope): for i, sub_node in enumerate(n_scope):
if sub_node.is_numeric(): if sub_node.is_numeric():
others = [scope[j] for j in range(i) + range(i + 1, l)] others = [n_scope[j] for j in range(i) + range(i + 1, l)]
if len(others) == 1: if len(others) == 1:
g = others[0] g = others[0]
else: else:
g = Node('*', *others) g = nary_node('*', others)
groups.append((sub_node, g, n)) groups.append((sub_node, g, n))
for g0, g1 in combinations(groups, 2): for (c0, g0, n0), (c1, g1, n1) in combinations(groups, 2):
if g0[1].equals(g1[1]): if g0.equals(g1, ignore_negation=True):
p.append(P(node, combine_groups, (root_scope,) + g0 + g1)) # Move negations to constants
c0 = c0.negate(g0.negated)
c1 = c1.negate(g1.negated)
g0 = negate(g0, 0)
g1 = negate(g1, 0)
p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
return p return p
...@@ -54,9 +60,6 @@ def match_combine_groups(node): ...@@ -54,9 +60,6 @@ def match_combine_groups(node):
def combine_groups(root, args): def combine_groups(root, args):
scope, c0, g0, n0, c1, g1, n1 = args scope, c0, g0, n0, c1, g1, n1 = args
if not isinstance(c0, Leaf) and not isinstance(c0, Node):
c0 = Leaf(c0)
# Replace the left node with the new expression # Replace the left node with the new expression
scope.replace(n0, (c0 + c1) * g0) scope.replace(n0, (c0 + c1) * g0)
......
...@@ -7,20 +7,31 @@ from tests.rulestestcase import RulesTestCase, tree ...@@ -7,20 +7,31 @@ from tests.rulestestcase import RulesTestCase, tree
class TestRulesGroups(RulesTestCase): class TestRulesGroups(RulesTestCase):
def test_match_combine_groups_no_const(self): def test_match_combine_groups_no_const(self):
a0, a1 = root = tree('a + a') root, l1 = tree('a + a,1')
a0, a1 = root
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), l1, a0, a0,
l1, a1, a1))])
def test_match_combine_groups_negation(self):
root, l1 = tree('-a + a,1')
a0, a1 = root
possibilities = match_combine_groups(root) possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), 1, a0, a0, [P(root, combine_groups, (Scope(root), -l1, +a0, a0,
1, a1, a1))]) l1, a1, a1))])
def test_match_combine_groups_single_const(self): def test_match_combine_groups_single_const(self):
a0, mul = root = tree('a + 2a') root, l1 = tree('a + 2a,1')
a0, mul = root
l2, a1 = mul l2, a1 = mul
possibilities = match_combine_groups(root) possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), 1, a0, a0, [P(root, combine_groups, (Scope(root), l1, a0, a0,
l2, a1, mul))]) l2, a1, mul))])
def test_match_combine_groups_two_const(self): def test_match_combine_groups_two_const(self):
...@@ -44,37 +55,40 @@ class TestRulesGroups(RulesTestCase): ...@@ -44,37 +55,40 @@ class TestRulesGroups(RulesTestCase):
l4, a2, m2))]) l4, a2, m2))])
def test_match_combine_groups_identifier_group_no_const(self): def test_match_combine_groups_identifier_group_no_const(self):
ab0, ab1 = root = tree('ab + ab') root, l1 = tree('ab + ab,1')
ab0, ab1 = root
possibilities = match_combine_groups(root) possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), 1, ab0, ab0, [P(root, combine_groups, (Scope(root), l1, ab0, ab0,
1, ab1, ab1))]) l1, ab1, ab1))])
def test_match_combine_groups_identifier_group_single_const(self): def test_match_combine_groups_identifier_group_single_const(self):
m0, m1 = root = tree('ab + 2ab') root, l1 = tree('ab + 2ab,1')
m0, m1 = root
(l2, a), b = m1 (l2, a), b = m1
possibilities = match_combine_groups(root) possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), 1, m0, m0, [P(root, combine_groups, (Scope(root), l1, m0, m0,
l2, a * b, m1))]) l2, a * b, m1))])
def test_match_combine_groups_identifier_group_unordered(self): def test_match_combine_groups_identifier_group_unordered(self):
m0, m1 = root = tree('ab + ba') root, l1 = tree('ab + ba,1')
m0, m1 = root
b, a = m1 b, a = m1
possibilities = match_combine_groups(root) possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, combine_groups, (Scope(root), 1, m0, m0, [P(root, combine_groups, (Scope(root), l1, m0, m0,
1, b * a, m1))]) l1, b * a, m1))])
def test_combine_groups_simple(self): def test_combine_groups_simple(self):
root, l1 = tree('a + a,1') root, l1 = tree('a + a,1')
a0, a1 = root a0, a1 = root
self.assertEqualNodes(combine_groups(root, self.assertEqualNodes(combine_groups(root,
(Scope(root), 1, a0, a0, 1, a1, a1)), (Scope(root), l1, a0, a0, l1, a1, a1)),
(l1 + 1) * a0) (l1 + 1) * a0)
def test_combine_groups_nary(self): def test_combine_groups_nary(self):
...@@ -83,5 +97,5 @@ class TestRulesGroups(RulesTestCase): ...@@ -83,5 +97,5 @@ class TestRulesGroups(RulesTestCase):
ab, b = abb ab, b = abb
self.assertEqualNodes(combine_groups(root, self.assertEqualNodes(combine_groups(root,
(Scope(root), 1, ab, ab, 1, ba, ba)), (Scope(root), l1, ab, ab, l1, ba, ba)),
(l1 + 1) * ab + b) (l1 + 1) * ab + b)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment