Commit 7f056382 authored by Sander Mathijs van Veen's avatar Sander Mathijs van Veen

Merge branch 'master' of kompiler.org:trs

parents d9b28ea2 6f628ddd
...@@ -227,7 +227,9 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -227,7 +227,9 @@ class ExpressionNode(Node, ExpressionBase):
return (self[1], self[0], ExpressionLeaf(1)) return (self[1], self[0], ExpressionLeaf(1))
def get_scope(self): def get_scope(self):
"""""" """
Find all n nodes within the n-ary scope of this operator.
"""
scope = [] scope = []
#op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
...@@ -241,6 +243,47 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -241,6 +243,47 @@ class ExpressionNode(Node, ExpressionBase):
return scope return scope
def equals(self, other):
"""
Perform a non-strict equivalence check between two nodes:
- If the other node is a leaf, it cannot be equal to this node.
- If their operators differ, the nodes are not equal.
- If both nodes are additions or both are multiplications, match each
node in one scope to one in the other (an injective relationship).
Any difference in order of the scopes is irrelevant.
- If both nodes are divisions, the nominator and denominator have to be
non-strictly equal.
"""
if not other.is_op(self.op):
return False
if self.op in (OP_ADD, OP_MUL):
s0 = self.get_scope()
s1 = set(other.get_scope())
# Scopes sould be of equal size
if len(s0) != len(s1):
return False
# Each node in one scope should have an image node in the other
matched = set()
for n0 in s0:
found = False
for n1 in s1 - matched:
if n0.equals(n1):
found = True
matched.add(n1)
break
if not found:
return False
elif self.op == OP_DIV:
return self[0].equals(other[0]) and self[1].equals(other[1])
return True
class ExpressionLeaf(Leaf, ExpressionBase): class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -249,6 +292,9 @@ class ExpressionLeaf(Leaf, ExpressionBase): ...@@ -249,6 +292,9 @@ class ExpressionLeaf(Leaf, ExpressionBase):
self.type = TYPE_MAP[type(args[0])] self.type = TYPE_MAP[type(args[0])]
def __eq__(self, other): def __eq__(self, other):
"""
Check strict equivalence.
"""
other_type = type(other) other_type = type(other)
if other_type in TYPE_MAP: if other_type in TYPE_MAP:
...@@ -256,6 +302,13 @@ class ExpressionLeaf(Leaf, ExpressionBase): ...@@ -256,6 +302,13 @@ class ExpressionLeaf(Leaf, ExpressionBase):
return other.type == self.type and self.value == other.value return other.type == self.type and self.value == other.value
def equals(self, other):
"""
Check non-strict equivalence.
Between leaves, this is the same as strict equivalence.
"""
return self == other
def extract_polynome_properties(self): def extract_polynome_properties(self):
""" """
An expression leaf will return the polynome tuple (1, r, 1), where r is An expression leaf will return the polynome tuple (1, r, 1), where r is
......
from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW
from .poly import match_combine_polynomes from .poly import match_combine_polynomes
from .groups import match_combine_groups
from .factors import match_expand from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \ from .powers import match_add_exponents, match_subtract_exponents, \
match_multiply_exponents, match_duplicate_exponent, \ match_multiply_exponents, match_duplicate_exponent, \
...@@ -10,7 +11,8 @@ from .fractions import match_constant_division, match_add_constant_fractions, \ ...@@ -10,7 +11,8 @@ from .fractions import match_constant_division, match_add_constant_fractions, \
RULES = { RULES = {
OP_ADD: [match_add_constant_fractions, match_combine_polynomes], OP_ADD: [match_add_constant_fractions, match_combine_groups, \
match_combine_polynomes],
OP_MUL: [match_expand, match_add_exponents, \ OP_MUL: [match_expand, match_add_exponents, \
match_expand_and_add_fractions], match_expand_and_add_fractions],
OP_DIV: [match_subtract_exponents, match_divide_numerics, \ OP_DIV: [match_subtract_exponents, match_divide_numerics, \
......
from itertools import combinations
from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
ExpressionLeaf as Leaf
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
def match_combine_groups(node):
"""
Match possible combinations of groups of expressions using non-strict
equivalence.
Examples:
a + a -> 2a
a + 2a -> 3a
ab + ab -> 2ab
ab + 2ab -> 3ab
ab + ba -> 2ab
"""
assert node.is_op(OP_ADD)
p = []
groups = []
for n in node.get_scope():
groups.append((1, n, n))
# Each number multiplication yields a group, multiple occurences of
# the same group can be replaced by a single one
if n.is_op(OP_MUL):
scope = n.get_scope()
l = len(scope)
for i, sub_node in enumerate(scope):
if sub_node.is_numeric():
others = [scope[j] for j in range(i) + range(i + 1, l)]
g = others[0] if len(others) == 1 else Node('*', *others)
groups.append((sub_node, g, n))
for g0, g1 in combinations(groups, 2):
if g0[1].equals(g1[1]):
p.append(P(node, combine_groups, g0 + g1))
return p
def combine_groups(root, args):
c0, g0, n0, c1, g1, n1 = args
scope = root.get_scope()
if not isinstance(c0, Leaf):
c0 = Leaf(c0)
# Replace the left node with the new expression
scope[scope.index(n0)] = (c0 + c1) * g0
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
import unittest import unittest
from src.node import ExpressionNode from src.node import ExpressionNode
from src.parser import Parser
from tests.parser import ParserWrapper
def tree(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp])
class RulesTestCase(unittest.TestCase): class RulesTestCase(unittest.TestCase):
......
import unittest import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L from src.node import ExpressionNode as N, ExpressionLeaf as L
from tests.rulestestcase import tree
class TestNode(unittest.TestCase): class TestNode(unittest.TestCase):
...@@ -88,3 +89,53 @@ class TestNode(unittest.TestCase): ...@@ -88,3 +89,53 @@ class TestNode(unittest.TestCase):
def test_get_scope_nested_deep(self): def test_get_scope_nested_deep(self):
plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3]) plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
self.assertEqual(plus.get_scope(), self.l) self.assertEqual(plus.get_scope(), self.l)
def test_equals_node_leaf(self):
a, b = plus = tree('a + b')
self.assertFalse(a.equals(plus))
self.assertFalse(plus.equals(a))
def test_equals_other_op(self):
plus, mul = tree('a + b, a * b')
self.assertFalse(plus.equals(mul))
def test_equals_add(self):
p0, p1, p2, p3 = tree('a + b,a + b,b + a, a + c')
self.assertTrue(p0.equals(p1))
self.assertTrue(p0.equals(p2))
self.assertFalse(p0.equals(p3))
self.assertFalse(p2.equals(p3))
def test_equals_mul(self):
m0, m1, m2, m3 = tree('a * b,a * b,b * a, a * c')
self.assertTrue(m0.equals(m1))
self.assertTrue(m0.equals(m2))
self.assertFalse(m0.equals(m3))
self.assertFalse(m2.equals(m3))
def test_equals_nary(self):
p0, p1, p2, p3, p4 = \
tree('a + b + c,a + c + b,b + a + c,b + c + a, a + b + d')
self.assertTrue(p0.equals(p1))
self.assertTrue(p0.equals(p2))
self.assertTrue(p0.equals(p3))
self.assertTrue(p1.equals(p2))
self.assertTrue(p1.equals(p3))
self.assertTrue(p2.equals(p3))
self.assertFalse(p2.equals(p4))
def test_equals_nary_mary(self):
m0, m1 = tree('ab,2ab')
self.assertFalse(m0.equals(m1))
def test_equals_div(self):
d0, d1, d2 = tree('a / b,a / b,b / a')
self.assertTrue(d0.equals(d1))
self.assertFalse(d0.equals(d2))
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
from src.possibilities import MESSAGES, Possibility as P, filter_duplicates from src.possibilities import MESSAGES, Possibility as P, filter_duplicates
from src.rules.numerics import add_numerics from src.rules.numerics import add_numerics
from tests.test_rules_poly import tree from tests.rulestestcase import tree
from src.parser import Parser from src.parser import Parser
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
......
from src.rules.factors import match_expand, expand_single, expand_double from src.rules.factors import match_expand, expand_single, expand_double
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase from tests.rulestestcase import RulesTestCase, tree
from tests.test_rules_poly import tree
class TestRulesFactors(RulesTestCase): class TestRulesFactors(RulesTestCase):
......
...@@ -2,8 +2,7 @@ from src.rules.fractions import match_constant_division, division_by_one, \ ...@@ -2,8 +2,7 @@ from src.rules.fractions import match_constant_division, division_by_one, \
division_of_zero, division_by_self, match_add_constant_fractions, \ division_of_zero, division_by_self, match_add_constant_fractions, \
equalize_denominators, add_nominators equalize_denominators, add_nominators
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from tests.test_rules_poly import tree from tests.rulestestcase import RulesTestCase, tree
from tests.rulestestcase import RulesTestCase
class TestRulesFractions(RulesTestCase): class TestRulesFractions(RulesTestCase):
......
from src.rules.groups import match_combine_groups, combine_groups
from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase, tree
class TestRulesGroups(RulesTestCase):
def test_match_combine_groups_no_const(self):
a0, a1 = root = tree('a + a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, a0, a0, 1, a1, a1))])
def test_match_combine_groups_single_const(self):
a0, mul = root = tree('a + 2a')
l2, a1 = mul
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, a0, a0, l2, a1, mul))])
def test_match_combine_groups_two_const(self):
((l2, a0), b), (l3, a1) = (m0, b), m1 = root = tree('2a + b + 3a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (l2, a0, m0, l3, a1, m1))])
def test_match_combine_groups_n_const(self):
((l2, a0), (l3, a1)), (l4, a2) = (m0, m1), m2 = root = tree('2a+3a+4a')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (l2, a0, m0, l3, a1, m1)),
P(root, combine_groups, (l2, a0, m0, l4, a2, m2)),
P(root, combine_groups, (l3, a1, m1, l4, a2, m2))])
def test_match_combine_groups_identifier_group_no_const(self):
ab0, ab1 = root = tree('ab + ab')
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, ab0, ab0, 1, ab1, ab1))])
def test_match_combine_groups_identifier_group_single_const(self):
m0, m1 = root = tree('ab + 2ab')
(l2, a), b = m1
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, m0, m0, l2, a * b, m1))])
def test_match_combine_groups_identifier_group_unordered(self):
m0, m1 = root = tree('ab + ba')
b, a = m1
possibilities = match_combine_groups(root)
self.assertEqualPos(possibilities,
[P(root, combine_groups, (1, m0, m0, 1, b * a, m1))])
def test_combine_groups_simple(self):
root, l1 = tree('a + a,1')
a0, a1 = root
self.assertEqualNodes(combine_groups(root, (1, a0, a0, 1, a1, a1)),
(l1 + 1) * a0)
def test_combine_groups_nary(self):
root, l1 = tree('ab + b + ba,1')
abb, ba = root
ab, b = abb
self.assertEqualNodes(combine_groups(root, (1, ab, ab, 1, ba, ba)),
(l1 + 1) * ab + b)
...@@ -2,8 +2,7 @@ from src.rules.numerics import add_numerics, match_divide_numerics, \ ...@@ -2,8 +2,7 @@ from src.rules.numerics import add_numerics, match_divide_numerics, \
divide_numerics, match_multiply_numerics, multiply_numerics divide_numerics, match_multiply_numerics, multiply_numerics
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from src.node import ExpressionLeaf as L from src.node import ExpressionLeaf as L
from tests.rulestestcase import RulesTestCase from tests.rulestestcase import RulesTestCase, tree
from tests.test_rules_poly import tree
class TestRulesNumerics(RulesTestCase): class TestRulesNumerics(RulesTestCase):
......
from src.rules.poly import match_combine_polynomes, combine_polynomes from src.rules.poly import match_combine_polynomes, combine_polynomes
from src.rules.numerics import add_numerics from src.rules.numerics import add_numerics
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from src.parser import Parser from tests.rulestestcase import RulesTestCase, tree
from tests.parser import ParserWrapper
from tests.rulestestcase import RulesTestCase
def tree(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp])
class TestRulesPoly(RulesTestCase): class TestRulesPoly(RulesTestCase):
......
...@@ -6,8 +6,7 @@ from src.rules.powers import match_add_exponents, add_exponents, \ ...@@ -6,8 +6,7 @@ from src.rules.powers import match_add_exponents, add_exponents, \
match_exponent_to_root, exponent_to_root match_exponent_to_root, exponent_to_root
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from src.node import ExpressionNode as N from src.node import ExpressionNode as N
from tests.test_rules_poly import tree from tests.rulestestcase import RulesTestCase, tree
from tests.rulestestcase import RulesTestCase
class TestRulesPowers(RulesTestCase): class TestRulesPowers(RulesTestCase):
......
from src.node import ExpressionNode as N from src.node import ExpressionNode as N
from src.rules.utils import nary_node, is_prime, least_common_multiple from src.rules.utils import nary_node, is_prime, least_common_multiple
from tests.test_rules_poly import tree from tests.rulestestcase import RulesTestCase, tree
from tests.rulestestcase import RulesTestCase
class TestRulesUtils(RulesTestCase): class TestRulesUtils(RulesTestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment