Restructured rules and added match_expand.

parent 49c547c8
from ..node import ExpressionNode as Node, OP_ADD from ..node import ExpressionNode as Node, OP_ADD, OP_MUL
from .poly import match_combine_factors#, match_combine_parentheses from .poly import match_combine_factors, match_expand
RULES = { RULES = {
OP_ADD: [match_combine_factors], OP_ADD: [match_combine_factors],
#OP_MUL: [match_combine_parentheses], OP_MUL: [match_expand],
} }
from itertools import combinations from itertools import combinations
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD
from ..possibilities import Possibility as P from ..possibilities import Possibility as P
def match_expand(node):
"""
a * (b + c) -> ab + ac
"""
if node.type != TYPE_OPERATOR or not node.op & OP_MUL:
return []
p = []
# 'a' parts
left = []
# '(b + c)' parts
right = []
for n in node.get_scope():
if node.type == TYPE_OPERATOR:
if n.op & OP_ADD:
right.append(n)
else:
left.append(n)
if len(left) and len(right):
for l in left:
for r in right:
p.append(P(node, expand_single, l, r))
return p
def expand_single(root, args):
"""
Combine a leaf (left) multiplied with an addition of two expressions
(right) to an addition of two multiplications.
>>> a * (b + c) -> a * b + a * c
"""
left, right = args
others = list(set(root.get_scope()) - {left, right})
def match_combine_factors(node): def match_combine_factors(node):
""" """
n + exp + m -> exp + (n + m) n + exp + m -> exp + (n + m)
k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
""" """
if node.type != TYPE_OPERATOR: if node.type != TYPE_OPERATOR or not node.op & OP_ADD:
return [] return []
p = [] p = []
if node.is_nary():
# Collect all nodes that can be combined # Collect all nodes that can be combined
# Numeric leaves # Numeric leaves
numerics = [] numerics = []
...@@ -22,26 +61,21 @@ def match_combine_factors(node): ...@@ -22,26 +61,21 @@ def match_combine_factors(node):
# (identifier, exponent, coefficient) # (identifier, exponent, coefficient)
orders = [] orders = []
# Nodes that cannot be combined
others = []
for n in node.get_scope(): for n in node.get_scope():
if isinstance(n, Leaf): if node.type == TYPE_OPERATOR:
if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
else:
order = n.get_order() order = n.get_order()
if order: if order:
orders += order orders += order
else: else:
others.append(n) if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
if len(numerics) > 1: if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2): for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1, others))) p.append(P(node, combine_numerics, (num0, num1)))
if len(orders) > 1: if len(orders) > 1:
for order0, order1 in combinations(orders, 2): for order0, order1 in combinations(orders, 2):
...@@ -50,7 +84,7 @@ def match_combine_factors(node): ...@@ -50,7 +84,7 @@ def match_combine_factors(node):
if id0 == id1 and exponent0 == exponent1: if id0 == id1 and exponent0 == exponent1:
# Same identifier and exponent -> combine coefficients # Same identifier and exponent -> combine coefficients
args = order0 + (coeff1,) + (others,) args = order0 + (coeff1,)
p.append(P(node, combine_orders, args)) p.append(P(node, combine_orders, args))
return p return p
...@@ -63,8 +97,8 @@ def combine_numerics(root, args): ...@@ -63,8 +97,8 @@ def combine_numerics(root, args):
Example: Example:
>>> 3 + 4 -> 7 >>> 3 + 4 -> 7
""" """
numerics, others = args others = list(set(root.get_scope()) - set(args))
value = sum([n.value for n in numerics]) value = sum([n.value for n in args])
return nary_node('+', others + [Leaf(value)]) return nary_node('+', others + [Leaf(value)])
......
...@@ -5,7 +5,7 @@ from src.node import ExpressionNode as N, ExpressionLeaf as L ...@@ -5,7 +5,7 @@ from src.node import ExpressionNode as N, ExpressionLeaf as L
from tests.parser import run_expressions from tests.parser import run_expressions
class TestB1Ch8(unittest.TestCase): class TestB1Ch08(unittest.TestCase):
def test_diagnostic_test(self): def test_diagnostic_test(self):
run_expressions(Parser, [ run_expressions(Parser, [
......
...@@ -23,15 +23,15 @@ class TestRules(unittest.TestCase): ...@@ -23,15 +23,15 @@ class TestRules(unittest.TestCase):
l0, l1 = L(1), L(2) l0, l1 = L(1), L(2)
plus = N('+', l0, l1) plus = N('+', l0, l1)
p = match_combine_factors(plus) p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, []))]) self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
def test_match_combine_factors_numeric_combinations(self): def test_match_combine_factors_numeric_combinations(self):
l0, l1, l2 = L(1), L(2), L(2) l0, l1, l2 = L(1), L(2), L(2)
plus = N('+', N('+', l0, l1), l2) plus = N('+', N('+', l0, l1), l2)
p = match_combine_factors(plus) p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, [])), self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
P(plus, combine_numerics, (l0, l2, [])), P(plus, combine_numerics, (l0, l2)),
P(plus, combine_numerics, (l1, l2, []))]) P(plus, combine_numerics, (l1, l2))])
def assertEqualPos(self, possibilities, expected): def assertEqualPos(self, possibilities, expected):
for p, e in zip(possibilities, expected): for p, e in zip(possibilities, expected):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment