Skip to content
Snippets Groups Projects
Commit 5f84e863 authored by Taddeus Kroes's avatar Taddeus Kroes
Browse files

Improved 'factor' rules so that 'Expand ...'-hints make more sense.

parent 76634490
No related branches found
No related tags found
No related merge requests found
from itertools import product, combinations
from itertools import product
from ..node import Scope, OP_ADD, OP_MUL
from .utils import is_numeric_node
from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
def is_expandable(node):
"""
Check if a node is expandable. Only additions that consist of not only
numerics can be expanded.
"""
return node.is_op(OP_ADD) \
and not all(map(is_numeric_node, Scope(node)))
def match_expand(node):
"""
a(b + c) -> ab + ac
(b + c)a -> ab + ac
Expand multiplication of non-numeric additions.
Examples:
(a + b)(c + d) -> ac + ad + bc + bd
(b + c)a -> ab + ac
a(b + c) -> ab + ac
"""
assert node.is_op(OP_MUL)
p = []
leaves = []
additions = []
scope = Scope(node)
l = len(scope)
for n in scope:
if n.is_leaf:
leaves.append(n)
elif n.op == OP_ADD:
# If the addition only contains numerics, do not expand
if not filter(lambda n: not n.is_numeric(), Scope(n)):
continue
additions.append(n)
for distance in range(1, l):
for i, left in enumerate(scope[:-distance]):
right = scope[i + distance]
l_expandable = is_expandable(left)
r_expandable = is_expandable(right)
for l, a in product(leaves, additions):
p.append(P(node, expand_single, (scope, l, a)))
for a0, a1 in combinations(additions, 2):
p.append(P(node, expand_double, (scope, a0, a1)))
if l_expandable and r_expandable:
p.append(P(node, expand_double, (scope, left, right)))
elif l_expandable ^ r_expandable:
p.append(P(node, expand_single, (scope, left, right)))
return p
def expand_single(root, args):
def expand(root, args):
"""
Combine a leaf (a) multiplied with an addition of two expressions
(b + c) to an addition of two multiplications.
a(b + c) -> ab + ac
(b + c)a -> ab + ac
(a + b)(c + d) -> ac + ad + bc + bd
(a + b)c -> ac + bc
a(b + c) -> ab + ac
etc..
"""
scope, a, bc = args
b, c = bc
# Replace 'a' with the new expression
scope.replace(a, a * b + a * c)
scope, left, right = args
# Remove the addition
scope.remove(bc)
left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
return scope.as_nary_node()
add_scope = [l * r for l, r in product(left_scope, right_scope)]
add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
add.negated = left.negated + right.negated
scope.replace(left, add)
scope.remove(right)
MESSAGES[expand_single] = _('Expand {2}({3}).')
return scope.as_nary_node()
def expand_double(root, args):
"""
Rewrite two multiplied additions to an addition of four multiplications.
return expand(root, args)
(a + b)(c + d) -> ac + ad + bc + bd
"""
scope, ab, cd = args
(a, b), (c, d) = ab, cd
# Replace 'a + b' with the new expression
scope.replace(ab, a * c + a * d + b * c + b * d)
MESSAGES[expand_double] = _('Expand ({2})({3}).')
# Remove the right addition
scope.remove(cd)
return scope.as_nary_node()
def expand_single(root, args):
return expand(root, args)
MESSAGES[expand_double] = _('Expand ({2})({3}).')
MESSAGES[expand_single] = _('Expand ({2})({3}).')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment