Restructured rules and added match_expand.

parent 49c547c8
from ..node import ExpressionNode as Node, OP_ADD
from .poly import match_combine_factors#, match_combine_parentheses
from ..node import ExpressionNode as Node, OP_ADD, OP_MUL
from .poly import match_combine_factors, match_expand
RULES = {
OP_ADD: [match_combine_factors],
#OP_MUL: [match_combine_parentheses],
OP_MUL: [match_expand],
}
from itertools import combinations
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD
from ..possibilities import Possibility as P
def match_expand(node):
"""
a * (b + c) -> ab + ac
"""
if node.type != TYPE_OPERATOR or not node.op & OP_MUL:
return []
p = []
# 'a' parts
left = []
# '(b + c)' parts
right = []
for n in node.get_scope():
if node.type == TYPE_OPERATOR:
if n.op & OP_ADD:
right.append(n)
else:
left.append(n)
if len(left) and len(right):
for l in left:
for r in right:
p.append(P(node, expand_single, l, r))
return p
def expand_single(root, args):
"""
Combine a leaf (left) multiplied with an addition of two expressions
(right) to an addition of two multiplications.
>>> a * (b + c) -> a * b + a * c
"""
left, right = args
others = list(set(root.get_scope()) - {left, right})
def match_combine_factors(node):
"""
n + exp + m -> exp + (n + m)
k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
"""
if node.type != TYPE_OPERATOR:
if node.type != TYPE_OPERATOR or not node.op & OP_ADD:
return []
p = []
if node.is_nary():
# Collect all nodes that can be combined
# Numeric leaves
numerics = []
# Identifier leaves of all orders, tuple format is;
# (identifier, exponent, coefficient)
orders = []
# Nodes that cannot be combined
others = []
for n in node.get_scope():
if isinstance(n, Leaf):
if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
else:
order = n.get_order()
if order:
orders += order
else:
others.append(n)
if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1, others)))
if len(orders) > 1:
for order0, order1 in combinations(orders, 2):
id0, exponent0, coeff0 = order0
id1, exponent1, coeff1 = order1
if id0 == id1 and exponent0 == exponent1:
# Same identifier and exponent -> combine coefficients
args = order0 + (coeff1,) + (others,)
p.append(P(node, combine_orders, args))
# Collect all nodes that can be combined
# Numeric leaves
numerics = []
# Identifier leaves of all orders, tuple format is;
# (identifier, exponent, coefficient)
orders = []
for n in node.get_scope():
if node.type == TYPE_OPERATOR:
order = n.get_order()
if order:
orders += order
else:
if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1)))
if len(orders) > 1:
for order0, order1 in combinations(orders, 2):
id0, exponent0, coeff0 = order0
id1, exponent1, coeff1 = order1
if id0 == id1 and exponent0 == exponent1:
# Same identifier and exponent -> combine coefficients
args = order0 + (coeff1,)
p.append(P(node, combine_orders, args))
return p
......@@ -63,8 +97,8 @@ def combine_numerics(root, args):
Example:
>>> 3 + 4 -> 7
"""
numerics, others = args
value = sum([n.value for n in numerics])
others = list(set(root.get_scope()) - set(args))
value = sum([n.value for n in args])
return nary_node('+', others + [Leaf(value)])
......
......@@ -5,7 +5,7 @@ from src.node import ExpressionNode as N, ExpressionLeaf as L
from tests.parser import run_expressions
class TestB1Ch8(unittest.TestCase):
class TestB1Ch08(unittest.TestCase):
def test_diagnostic_test(self):
run_expressions(Parser, [
......
......@@ -23,15 +23,15 @@ class TestRules(unittest.TestCase):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, []))])
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
def test_match_combine_factors_numeric_combinations(self):
l0, l1, l2 = L(1), L(2), L(2)
plus = N('+', N('+', l0, l1), l2)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, [])),
P(plus, combine_numerics, (l0, l2, [])),
P(plus, combine_numerics, (l1, l2, []))])
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
P(plus, combine_numerics, (l0, l2)),
P(plus, combine_numerics, (l1, l2))])
def assertEqualPos(self, possibilities, expected):
for p, e in zip(possibilities, expected):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment