Commit 5f84e863 authored by Taddeus Kroes's avatar Taddeus Kroes

Improved 'factor' rules so that 'Expand ...'-hints make more sense.

parent 76634490
from itertools import product, combinations from itertools import product
from ..node import Scope, OP_ADD, OP_MUL from .utils import is_numeric_node
from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
def is_expandable(node):
"""
Check if a node is expandable. Only additions that consist of not only
numerics can be expanded.
"""
return node.is_op(OP_ADD) \
and not all(map(is_numeric_node, Scope(node)))
def match_expand(node): def match_expand(node):
""" """
a(b + c) -> ab + ac Expand multiplication of non-numeric additions.
(b + c)a -> ab + ac
Examples:
(a + b)(c + d) -> ac + ad + bc + bd (a + b)(c + d) -> ac + ad + bc + bd
(b + c)a -> ab + ac
a(b + c) -> ab + ac
""" """
assert node.is_op(OP_MUL) assert node.is_op(OP_MUL)
p = [] p = []
leaves = []
additions = []
scope = Scope(node) scope = Scope(node)
l = len(scope)
for n in scope: for distance in range(1, l):
if n.is_leaf: for i, left in enumerate(scope[:-distance]):
leaves.append(n) right = scope[i + distance]
elif n.op == OP_ADD: l_expandable = is_expandable(left)
# If the addition only contains numerics, do not expand r_expandable = is_expandable(right)
if not filter(lambda n: not n.is_numeric(), Scope(n)):
continue
additions.append(n)
for l, a in product(leaves, additions): if l_expandable and r_expandable:
p.append(P(node, expand_single, (scope, l, a))) p.append(P(node, expand_double, (scope, left, right)))
elif l_expandable ^ r_expandable:
for a0, a1 in combinations(additions, 2): p.append(P(node, expand_single, (scope, left, right)))
p.append(P(node, expand_double, (scope, a0, a1)))
return p return p
def expand_single(root, args): def expand(root, args):
""" """
Combine a leaf (a) multiplied with an addition of two expressions (a + b)(c + d) -> ac + ad + bc + bd
(b + c) to an addition of two multiplications. (a + b)c -> ac + bc
a(b + c) -> ab + ac a(b + c) -> ab + ac
(b + c)a -> ab + ac etc..
""" """
scope, a, bc = args scope, left, right = args
b, c = bc
# Replace 'a' with the new expression
scope.replace(a, a * b + a * c)
# Remove the addition left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
scope.remove(bc) right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
return scope.as_nary_node() add_scope = [l * r for l, r in product(left_scope, right_scope)]
add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
add.negated = left.negated + right.negated
scope.replace(left, add)
scope.remove(right)
MESSAGES[expand_single] = _('Expand {2}({3}).') return scope.as_nary_node()
def expand_double(root, args): def expand_double(root, args):
""" return expand(root, args)
Rewrite two multiplied additions to an addition of four multiplications.
(a + b)(c + d) -> ac + ad + bc + bd
"""
scope, ab, cd = args
(a, b), (c, d) = ab, cd
# Replace 'a + b' with the new expression MESSAGES[expand_double] = _('Expand ({2})({3}).')
scope.replace(ab, a * c + a * d + b * c + b * d)
# Remove the right addition
scope.remove(cd)
return scope.as_nary_node() def expand_single(root, args):
return expand(root, args)
MESSAGES[expand_double] = _('Expand ({2})({3}).') MESSAGES[expand_single] = _('Expand ({2})({3}).')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment