소스 검색

Merge branch 'master' of kompiler.org:trs

Sander Mathijs van Veen 14 년 전
부모
커밋
7f05638228

+ 54 - 1
src/node.py

@@ -227,7 +227,9 @@ class ExpressionNode(Node, ExpressionBase):
         return (self[1], self[0], ExpressionLeaf(1))
 
     def get_scope(self):
-        """"""
+        """
+        Find all n nodes within the n-ary scope of this operator.
+        """
         scope = []
         #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
 
@@ -241,6 +243,47 @@ class ExpressionNode(Node, ExpressionBase):
 
         return scope
 
+    def equals(self, other):
+        """
+        Perform a non-strict equivalence check between two nodes:
+        - If the other node is a leaf, it cannot be equal to this node.
+        - If their operators differ, the nodes are not equal.
+        - If both nodes are additions or both are multiplications, match each
+          node in one scope to one in the other (an injective relationship).
+          Any difference in order of the scopes is irrelevant.
+        - If both nodes are divisions, the nominator and denominator have to be
+          non-strictly equal.
+        """
+        if not other.is_op(self.op):
+            return False
+
+        if self.op in (OP_ADD, OP_MUL):
+            s0 = self.get_scope()
+            s1 = set(other.get_scope())
+
+            # Scopes sould be of equal size
+            if len(s0) != len(s1):
+                return False
+
+            # Each node in one scope should have an image node in the other
+            matched = set()
+
+            for n0 in s0:
+                found = False
+
+                for n1 in s1 - matched:
+                    if n0.equals(n1):
+                        found = True
+                        matched.add(n1)
+                        break
+
+                if not found:
+                    return False
+        elif self.op == OP_DIV:
+            return self[0].equals(other[0]) and self[1].equals(other[1])
+
+        return True
+
 
 class ExpressionLeaf(Leaf, ExpressionBase):
     def __init__(self, *args, **kwargs):
@@ -249,6 +292,9 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         self.type = TYPE_MAP[type(args[0])]
 
     def __eq__(self, other):
+        """
+        Check strict equivalence.
+        """
         other_type = type(other)
 
         if other_type in TYPE_MAP:
@@ -256,6 +302,13 @@ class ExpressionLeaf(Leaf, ExpressionBase):
 
         return other.type == self.type and self.value == other.value
 
+    def equals(self, other):
+        """
+        Check non-strict equivalence.
+        Between leaves, this is the same as strict equivalence.
+        """
+        return self == other
+
     def extract_polynome_properties(self):
         """
         An expression leaf will return the polynome tuple (1, r, 1), where r is

+ 3 - 1
src/rules/__init__.py

@@ -1,5 +1,6 @@
 from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW
 from .poly import match_combine_polynomes
+from .groups import match_combine_groups
 from .factors import match_expand
 from .powers import match_add_exponents, match_subtract_exponents, \
         match_multiply_exponents, match_duplicate_exponent, \
@@ -10,7 +11,8 @@ from .fractions import match_constant_division, match_add_constant_fractions, \
 
 
 RULES = {
-        OP_ADD: [match_add_constant_fractions, match_combine_polynomes],
+        OP_ADD: [match_add_constant_fractions, match_combine_groups, \
+                 match_combine_polynomes],
         OP_MUL: [match_expand, match_add_exponents, \
                  match_expand_and_add_fractions],
         OP_DIV: [match_subtract_exponents, match_divide_numerics, \

+ 62 - 0
src/rules/groups.py

@@ -0,0 +1,62 @@
+from itertools import combinations
+
+from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
+        ExpressionLeaf as Leaf
+from ..possibilities import Possibility as P, MESSAGES
+from .utils import nary_node
+
+
+def match_combine_groups(node):
+    """
+    Match possible combinations of groups of expressions using non-strict
+    equivalence.
+
+    Examples:
+    a + a     ->  2a
+    a + 2a    ->  3a
+    ab + ab   ->  2ab
+    ab + 2ab  ->  3ab
+    ab + ba   ->  2ab
+    """
+    assert node.is_op(OP_ADD)
+
+    p = []
+    groups = []
+
+    for n in node.get_scope():
+        groups.append((1, n, n))
+
+        # Each number multiplication yields a group, multiple occurences of
+        # the same group can be replaced by a single one
+        if n.is_op(OP_MUL):
+            scope = n.get_scope()
+            l = len(scope)
+
+            for i, sub_node in enumerate(scope):
+                if sub_node.is_numeric():
+                    others = [scope[j] for j in range(i) + range(i + 1, l)]
+                    g = others[0] if len(others) == 1 else Node('*', *others)
+                    groups.append((sub_node, g, n))
+
+    for g0, g1 in combinations(groups, 2):
+        if g0[1].equals(g1[1]):
+            p.append(P(node, combine_groups, g0 + g1))
+
+    return p
+
+
+def combine_groups(root, args):
+    c0, g0, n0, c1, g1, n1 = args
+
+    scope = root.get_scope()
+
+    if not isinstance(c0, Leaf):
+        c0 = Leaf(c0)
+
+    # Replace the left node with the new expression
+    scope[scope.index(n0)] = (c0 + c1) * g0
+
+    # Remove the right node
+    scope.remove(n1)
+
+    return nary_node('+', scope)

+ 6 - 0
tests/rulestestcase.py

@@ -1,5 +1,11 @@
 import unittest
 from src.node import ExpressionNode
+from src.parser import Parser
+from tests.parser import ParserWrapper
+
+
+def tree(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp])
 
 
 class RulesTestCase(unittest.TestCase):

+ 51 - 0
tests/test_node.py

@@ -1,6 +1,7 @@
 import unittest
 
 from src.node import ExpressionNode as N, ExpressionLeaf as L
+from tests.rulestestcase import tree
 
 
 class TestNode(unittest.TestCase):
@@ -88,3 +89,53 @@ class TestNode(unittest.TestCase):
     def test_get_scope_nested_deep(self):
         plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
         self.assertEqual(plus.get_scope(), self.l)
+
+    def test_equals_node_leaf(self):
+        a, b = plus = tree('a + b')
+
+        self.assertFalse(a.equals(plus))
+        self.assertFalse(plus.equals(a))
+
+    def test_equals_other_op(self):
+        plus, mul = tree('a + b, a * b')
+
+        self.assertFalse(plus.equals(mul))
+
+    def test_equals_add(self):
+        p0, p1, p2, p3 = tree('a + b,a + b,b + a, a + c')
+
+        self.assertTrue(p0.equals(p1))
+        self.assertTrue(p0.equals(p2))
+        self.assertFalse(p0.equals(p3))
+        self.assertFalse(p2.equals(p3))
+
+    def test_equals_mul(self):
+        m0, m1, m2, m3 = tree('a * b,a * b,b * a, a * c')
+
+        self.assertTrue(m0.equals(m1))
+        self.assertTrue(m0.equals(m2))
+        self.assertFalse(m0.equals(m3))
+        self.assertFalse(m2.equals(m3))
+
+    def test_equals_nary(self):
+        p0, p1, p2, p3, p4 = \
+                tree('a + b + c,a + c + b,b + a + c,b + c + a, a + b + d')
+
+        self.assertTrue(p0.equals(p1))
+        self.assertTrue(p0.equals(p2))
+        self.assertTrue(p0.equals(p3))
+        self.assertTrue(p1.equals(p2))
+        self.assertTrue(p1.equals(p3))
+        self.assertTrue(p2.equals(p3))
+        self.assertFalse(p2.equals(p4))
+
+    def test_equals_nary_mary(self):
+        m0, m1 = tree('ab,2ab')
+
+        self.assertFalse(m0.equals(m1))
+
+    def test_equals_div(self):
+        d0, d1, d2 = tree('a / b,a / b,b / a')
+
+        self.assertTrue(d0.equals(d1))
+        self.assertFalse(d0.equals(d2))

+ 1 - 1
tests/test_possibilities.py

@@ -2,7 +2,7 @@ import unittest
 
 from src.possibilities import MESSAGES, Possibility as P, filter_duplicates
 from src.rules.numerics import add_numerics
-from tests.test_rules_poly import tree
+from tests.rulestestcase import tree
 
 from src.parser import Parser
 from tests.parser import ParserWrapper

+ 1 - 2
tests/test_rules_factors.py

@@ -1,7 +1,6 @@
 from src.rules.factors import match_expand, expand_single, expand_double
 from src.possibilities import Possibility as P
-from tests.rulestestcase import RulesTestCase
-from tests.test_rules_poly import tree
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesFactors(RulesTestCase):

+ 1 - 2
tests/test_rules_fractions.py

@@ -2,8 +2,7 @@ from src.rules.fractions import match_constant_division, division_by_one, \
         division_of_zero, division_by_self, match_add_constant_fractions, \
         equalize_denominators, add_nominators
 from src.possibilities import Possibility as P
-from tests.test_rules_poly import tree
-from tests.rulestestcase import RulesTestCase
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesFractions(RulesTestCase):

+ 75 - 0
tests/test_rules_groups.py

@@ -0,0 +1,75 @@
+from src.rules.groups import match_combine_groups, combine_groups
+from src.possibilities import Possibility as P
+from tests.rulestestcase import RulesTestCase, tree
+
+
+class TestRulesGroups(RulesTestCase):
+
+    def test_match_combine_groups_no_const(self):
+        a0, a1 = root = tree('a + a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, a0, a0, 1, a1, a1))])
+
+    def test_match_combine_groups_single_const(self):
+        a0, mul = root = tree('a + 2a')
+        l2, a1 = mul
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, a0, a0, l2, a1, mul))])
+
+    def test_match_combine_groups_two_const(self):
+        ((l2, a0), b), (l3, a1) = (m0, b), m1 = root = tree('2a + b + 3a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (l2, a0, m0, l3, a1, m1))])
+
+    def test_match_combine_groups_n_const(self):
+        ((l2, a0), (l3, a1)), (l4, a2) = (m0, m1), m2 = root = tree('2a+3a+4a')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (l2, a0, m0, l3, a1, m1)),
+                 P(root, combine_groups, (l2, a0, m0, l4, a2, m2)),
+                 P(root, combine_groups, (l3, a1, m1, l4, a2, m2))])
+
+    def test_match_combine_groups_identifier_group_no_const(self):
+        ab0, ab1 = root = tree('ab + ab')
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, ab0, ab0, 1, ab1, ab1))])
+
+    def test_match_combine_groups_identifier_group_single_const(self):
+        m0, m1 = root = tree('ab + 2ab')
+        (l2, a), b = m1
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, m0, m0, l2, a * b, m1))])
+
+    def test_match_combine_groups_identifier_group_unordered(self):
+        m0, m1 = root = tree('ab + ba')
+        b, a = m1
+
+        possibilities = match_combine_groups(root)
+        self.assertEqualPos(possibilities,
+                [P(root, combine_groups, (1, m0, m0, 1, b * a, m1))])
+
+    def test_combine_groups_simple(self):
+        root, l1 = tree('a + a,1')
+        a0, a1 = root
+
+        self.assertEqualNodes(combine_groups(root, (1, a0, a0, 1, a1, a1)),
+                              (l1 + 1) * a0)
+
+    def test_combine_groups_nary(self):
+        root, l1 = tree('ab + b + ba,1')
+        abb, ba = root
+        ab, b = abb
+
+        self.assertEqualNodes(combine_groups(root, (1, ab, ab, 1, ba, ba)),
+                              (l1 + 1) * ab + b)

+ 1 - 2
tests/test_rules_numerics.py

@@ -2,8 +2,7 @@ from src.rules.numerics import add_numerics, match_divide_numerics, \
         divide_numerics, match_multiply_numerics, multiply_numerics
 from src.possibilities import Possibility as P
 from src.node import ExpressionLeaf as L
-from tests.rulestestcase import RulesTestCase
-from tests.test_rules_poly import tree
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesNumerics(RulesTestCase):

+ 1 - 7
tests/test_rules_poly.py

@@ -1,13 +1,7 @@
 from src.rules.poly import match_combine_polynomes, combine_polynomes
 from src.rules.numerics import add_numerics
 from src.possibilities import Possibility as P
-from src.parser import Parser
-from tests.parser import ParserWrapper
-from tests.rulestestcase import RulesTestCase
-
-
-def tree(exp, **kwargs):
-    return ParserWrapper(Parser, **kwargs).run([exp])
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesPoly(RulesTestCase):

+ 1 - 2
tests/test_rules_powers.py

@@ -6,8 +6,7 @@ from src.rules.powers import match_add_exponents, add_exponents, \
         match_exponent_to_root, exponent_to_root
 from src.possibilities import Possibility as P
 from src.node import ExpressionNode as N
-from tests.test_rules_poly import tree
-from tests.rulestestcase import RulesTestCase
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesPowers(RulesTestCase):

+ 1 - 2
tests/test_rules_utils.py

@@ -1,7 +1,6 @@
 from src.node import ExpressionNode as N
 from src.rules.utils import nary_node, is_prime, least_common_multiple
-from tests.test_rules_poly import tree
-from tests.rulestestcase import RulesTestCase
+from tests.rulestestcase import RulesTestCase, tree
 
 
 class TestRulesUtils(RulesTestCase):