Ver código fonte

Group combinations now support negated constants.

Taddeus Kroes 14 anos atrás
pai
commit
6e8d634547
3 arquivos alterados com 51 adições e 30 exclusões
  1. 4 0
      src/rules/factors.py
  2. 18 15
      src/rules/groups.py
  3. 29 15
      tests/test_rules_groups.py

+ 4 - 0
src/rules/factors.py

@@ -21,6 +21,10 @@ def match_expand(node):
         if n.is_leaf:
             leaves.append(n)
         elif n.op == OP_ADD:
+            # If the addition only contains numerics, do not expand
+            if not filter(lambda n: not n.is_numeric(), Scope(n)):
+                continue
+
             additions.append(n)
 
     for args in product(leaves, additions):

+ 18 - 15
src/rules/groups.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
-        OP_ADD, OP_MUL
+        OP_ADD, OP_MUL, nary_node, negate
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -22,31 +22,37 @@ def match_combine_groups(node):
 
     p = []
     groups = []
-    root_scope = Scope(node)
+    scope = Scope(node)
 
-    for n in root_scope:
-        groups.append((1, n, n))
+    for n in scope:
+        if not n.is_numeric():
+            groups.append((Leaf(1), n, n))
 
         # Each number multiplication yields a group, multiple occurences of
         # the same group can be replaced by a single one
         if n.is_op(OP_MUL):
-            scope = Scope(n)
-            l = len(scope)
+            n_scope = Scope(n)
+            l = len(n_scope)
 
-            for i, sub_node in enumerate(scope):
+            for i, sub_node in enumerate(n_scope):
                 if sub_node.is_numeric():
-                    others = [scope[j] for j in range(i) + range(i + 1, l)]
+                    others = [n_scope[j] for j in range(i) + range(i + 1, l)]
 
                     if len(others) == 1:
                         g = others[0]
                     else:
-                        g = Node('*', *others)
+                        g = nary_node('*', others)
 
                     groups.append((sub_node, g, n))
 
-    for g0, g1 in combinations(groups, 2):
-        if g0[1].equals(g1[1]):
-            p.append(P(node, combine_groups, (root_scope,) + g0 + g1))
+    for (c0, g0, n0), (c1, g1, n1) in combinations(groups, 2):
+        if g0.equals(g1, ignore_negation=True):
+            # Move negations to constants
+            c0 = c0.negate(g0.negated)
+            c1 = c1.negate(g1.negated)
+            g0 = negate(g0, 0)
+            g1 = negate(g1, 0)
+            p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
 
     return p
 
@@ -54,9 +60,6 @@ def match_combine_groups(node):
 def combine_groups(root, args):
     scope, c0, g0, n0, c1, g1, n1 = args
 
-    if not isinstance(c0, Leaf) and not isinstance(c0, Node):
-        c0 = Leaf(c0)
-
     # Replace the left node with the new expression
     scope.replace(n0, (c0 + c1) * g0)
 

+ 29 - 15
tests/test_rules_groups.py

@@ -7,20 +7,31 @@ from tests.rulestestcase import RulesTestCase, tree
 class TestRulesGroups(RulesTestCase):
 
     def test_match_combine_groups_no_const(self):
-        a0, a1 = root = tree('a + a')
+        root, l1 = tree('a + a,1')
+        a0, a1 = root
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (Scope(root), l1, a0, a0,
+                                                       l1, a1, a1))])
+
+    def test_match_combine_groups_negation(self):
+        root, l1 = tree('-a + a,1')
+        a0, a1 = root
 
         possibilities = match_combine_groups(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_groups, (Scope(root), 1, a0, a0,
-                                                       1, a1, a1))])
+                [P(root, combine_groups, (Scope(root), -l1, +a0, a0,
+                                                       l1, a1, a1))])
 
     def test_match_combine_groups_single_const(self):
-        a0, mul = root = tree('a + 2a')
+        root, l1 = tree('a + 2a,1')
+        a0, mul = root
         l2, a1 = mul
 
         possibilities = match_combine_groups(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_groups, (Scope(root), 1, a0, a0,
+                [P(root, combine_groups, (Scope(root), l1, a0, a0,
                                                        l2, a1, mul))])
 
     def test_match_combine_groups_two_const(self):
@@ -44,37 +55,40 @@ class TestRulesGroups(RulesTestCase):
                                                        l4, a2, m2))])
 
     def test_match_combine_groups_identifier_group_no_const(self):
-        ab0, ab1 = root = tree('ab + ab')
+        root, l1 = tree('ab + ab,1')
+        ab0, ab1 = root
 
         possibilities = match_combine_groups(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_groups, (Scope(root), 1, ab0, ab0,
-                                                       1, ab1, ab1))])
+                [P(root, combine_groups, (Scope(root), l1, ab0, ab0,
+                                                       l1, ab1, ab1))])
 
     def test_match_combine_groups_identifier_group_single_const(self):
-        m0, m1 = root = tree('ab + 2ab')
+        root, l1 = tree('ab + 2ab,1')
+        m0, m1 = root
         (l2, a), b = m1
 
         possibilities = match_combine_groups(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_groups, (Scope(root), 1, m0, m0,
+                [P(root, combine_groups, (Scope(root), l1, m0, m0,
                                                        l2, a * b, m1))])
 
     def test_match_combine_groups_identifier_group_unordered(self):
-        m0, m1 = root = tree('ab + ba')
+        root, l1 = tree('ab + ba,1')
+        m0, m1 = root
         b, a = m1
 
         possibilities = match_combine_groups(root)
         self.assertEqualPos(possibilities,
-                [P(root, combine_groups, (Scope(root), 1, m0, m0,
-                                                       1, b * a, m1))])
+                [P(root, combine_groups, (Scope(root), l1, m0, m0,
+                                                       l1, b * a, m1))])
 
     def test_combine_groups_simple(self):
         root, l1 = tree('a + a,1')
         a0, a1 = root
 
         self.assertEqualNodes(combine_groups(root,
-                              (Scope(root), 1, a0, a0, 1, a1, a1)),
+                              (Scope(root), l1, a0, a0, l1, a1, a1)),
                               (l1 + 1) * a0)
 
     def test_combine_groups_nary(self):
@@ -83,5 +97,5 @@ class TestRulesGroups(RulesTestCase):
         ab, b = abb
 
         self.assertEqualNodes(combine_groups(root,
-                              (Scope(root), 1, ab, ab, 1, ba, ba)),
+                              (Scope(root), l1, ab, ab, l1, ba, ba)),
                               (l1 + 1) * ab + b)