Procházet zdrojové kódy

Added group combination to rules.

Taddeus Kroes před 14 roky
rodič
revize
6f628ddddb
3 změnil soubory, kde provedl 134 přidání a 4 odebrání
  1. 3 1
      src/rules/__init__.py
  2. 62 0
      src/rules/groups.py
  3. 69 3
      tests/test_rules_groups.py

+ 3 - 1
src/rules/__init__.py

@@ -1,5 +1,6 @@
 from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW
 from .poly import match_combine_polynomes
+from .groups import match_combine_groups
 from .factors import match_expand
 from .powers import match_add_exponents, match_subtract_exponents, \
         match_multiply_exponents, match_duplicate_exponent, \
@@ -10,7 +11,8 @@ from .fractions import match_constant_division, match_add_constant_fractions, \
 
 
 RULES = {
-        OP_ADD: [match_add_constant_fractions, match_combine_polynomes],
+        OP_ADD: [match_add_constant_fractions, match_combine_groups, \
+                 match_combine_polynomes],
         OP_MUL: [match_expand, match_add_exponents, \
                  match_expand_and_add_fractions],
         OP_DIV: [match_subtract_exponents, match_divide_numerics, \

+ 62 - 0
src/rules/groups.py

@@ -0,0 +1,62 @@
+from itertools import combinations
+
+from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
+        ExpressionLeaf as Leaf
+from ..possibilities import Possibility as P, MESSAGES
+from .utils import nary_node
+
+
+def match_combine_groups(node):
+    """
+    Match possible combinations of groups of expressions using non-strict
+    equivalence.
+
+    Examples:
+    a + a     ->  2a
+    a + 2a    ->  3a
+    ab + ab   ->  2ab
+    ab + 2ab  ->  3ab
+    ab + ba   ->  2ab
+    """
+    assert node.is_op(OP_ADD)
+
+    p = []
+    groups = []
+
+    for n in node.get_scope():
+        groups.append((1, n, n))
+
+        # Each number multiplication yields a group, multiple occurences of
+        # the same group can be replaced by a single one
+        if n.is_op(OP_MUL):
+            scope = n.get_scope()
+            l = len(scope)
+
+            for i, sub_node in enumerate(scope):
+                if sub_node.is_numeric():
+                    others = [scope[j] for j in range(i) + range(i + 1, l)]
+                    g = others[0] if len(others) == 1 else Node('*', *others)
+                    groups.append((sub_node, g, n))
+
+    for g0, g1 in combinations(groups, 2):
+        if g0[1].equals(g1[1]):
+            p.append(P(node, combine_groups, g0 + g1))
+
+    return p
+
+
+def combine_groups(root, args):
+    c0, g0, n0, c1, g1, n1 = args
+
+    scope = root.get_scope()
+
+    if not isinstance(c0, Leaf):
+        c0 = Leaf(c0)
+
+    # Replace the left node with the new expression
+    scope[scope.index(n0)] = (c0 + c1) * g0
+
+    # Remove the right node
+    scope.remove(n1)
+
+    return nary_node('+', scope)

+ 69 - 3
tests/test_rules_groups.py

@@ -1,9 +1,75 @@
-from src.rules.groups import match_combine_groups
+from src.rules.groups import match_combine_groups, combine_groups
 from src.possibilities import Possibility as P
 from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesGroups(RulesTestCase):
 
-    def test_(self):
-        pass
+    def test_match_combine_groups_no_const(self):
+        a0, a1 = root = tree('a + a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, a0, a0, 1, a1, a1))])
+
+    def test_match_combine_groups_single_const(self):
+        a0, mul = root = tree('a + 2a')
+        l2, a1 = mul
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, a0, a0, l2, a1, mul))])
+
+    def test_match_combine_groups_two_const(self):
+        ((l2, a0), b), (l3, a1) = (m0, b), m1 = root = tree('2a + b + 3a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (l2, a0, m0, l3, a1, m1))])
+
+    def test_match_combine_groups_n_const(self):
+        ((l2, a0), (l3, a1)), (l4, a2) = (m0, m1), m2 = root = tree('2a+3a+4a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (l2, a0, m0, l3, a1, m1)),
+                 P(root, combine_groups, (l2, a0, m0, l4, a2, m2)),
+                 P(root, combine_groups, (l3, a1, m1, l4, a2, m2))])
+
+    def test_match_combine_groups_identifier_group_no_const(self):
+        ab0, ab1 = root = tree('ab + ab')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, ab0, ab0, 1, ab1, ab1))])
+
+    def test_match_combine_groups_identifier_group_single_const(self):
+        m0, m1 = root = tree('ab + 2ab')
+        (l2, a), b = m1
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, m0, m0, l2, a * b, m1))])
+
+    def test_match_combine_groups_identifier_group_unordered(self):
+        m0, m1 = root = tree('ab + ba')
+        b, a = m1
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, m0, m0, 1, b * a, m1))])
+
+    def test_combine_groups_simple(self):
+        root, l1 = tree('a + a,1')
+        a0, a1 = root
+
+        self.assertEqualNodes(combine_groups(root, (1, a0, a0, 1, a1, a1)),
+                              (l1 + 1) * a0)
+
+    def test_combine_groups_nary(self):
+        root, l1 = tree('ab + b + ba,1')
+        abb, ba = root
+        ab, b = abb
+
+        self.assertEqualNodes(combine_groups(root, (1, ab, ab, 1, ba, ba)),
+                              (l1 + 1) * ab + b)