Commit 5f84e863 authored by Taddeus Kroes's avatar Taddeus Kroes

Improved 'factor' rules so that 'Expand ...'-hints make more sense.

parent 76634490
from itertools import product, combinations
from itertools import product
from ..node import Scope, OP_ADD, OP_MUL
from .utils import is_numeric_node
from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
def is_expandable(node):
"""
Check if a node is expandable. Only additions that consist of not only
numerics can be expanded.
"""
return node.is_op(OP_ADD) \
and not all(map(is_numeric_node, Scope(node)))
def match_expand(node):
"""
a(b + c) -> ab + ac
(b + c)a -> ab + ac
Expand multiplication of non-numeric additions.
Examples:
(a + b)(c + d) -> ac + ad + bc + bd
(b + c)a -> ab + ac
a(b + c) -> ab + ac
"""
assert node.is_op(OP_MUL)
p = []
leaves = []
additions = []
scope = Scope(node)
l = len(scope)
for n in scope:
if n.is_leaf:
leaves.append(n)
elif n.op == OP_ADD:
# If the addition only contains numerics, do not expand
if not filter(lambda n: not n.is_numeric(), Scope(n)):
continue
additions.append(n)
for distance in range(1, l):
for i, left in enumerate(scope[:-distance]):
right = scope[i + distance]
l_expandable = is_expandable(left)
r_expandable = is_expandable(right)
for l, a in product(leaves, additions):
p.append(P(node, expand_single, (scope, l, a)))
for a0, a1 in combinations(additions, 2):
p.append(P(node, expand_double, (scope, a0, a1)))
if l_expandable and r_expandable:
p.append(P(node, expand_double, (scope, left, right)))
elif l_expandable ^ r_expandable:
p.append(P(node, expand_single, (scope, left, right)))
return p
def expand_single(root, args):
def expand(root, args):
"""
Combine a leaf (a) multiplied with an addition of two expressions
(b + c) to an addition of two multiplications.
a(b + c) -> ab + ac
(b + c)a -> ab + ac
(a + b)(c + d) -> ac + ad + bc + bd
(a + b)c -> ac + bc
a(b + c) -> ab + ac
etc..
"""
scope, a, bc = args
b, c = bc
# Replace 'a' with the new expression
scope.replace(a, a * b + a * c)
scope, left, right = args
# Remove the addition
scope.remove(bc)
left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
return scope.as_nary_node()
add_scope = [l * r for l, r in product(left_scope, right_scope)]
add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
add.negated = left.negated + right.negated
scope.replace(left, add)
scope.remove(right)
MESSAGES[expand_single] = _('Expand {2}({3}).')
return scope.as_nary_node()
def expand_double(root, args):
"""
Rewrite two multiplied additions to an addition of four multiplications.
return expand(root, args)
(a + b)(c + d) -> ac + ad + bc + bd
"""
scope, ab, cd = args
(a, b), (c, d) = ab, cd
# Replace 'a + b' with the new expression
scope.replace(ab, a * c + a * d + b * c + b * d)
MESSAGES[expand_double] = _('Expand ({2})({3}).')
# Remove the right addition
scope.remove(cd)
return scope.as_nary_node()
def expand_single(root, args):
return expand(root, args)
MESSAGES[expand_double] = _('Expand ({2})({3}).')
MESSAGES[expand_single] = _('Expand ({2})({3}).')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment