|
@@ -1,81 +1,80 @@
|
|
|
-from itertools import product, combinations
|
|
|
|
|
|
|
+from itertools import product
|
|
|
|
|
|
|
|
-from ..node import Scope, OP_ADD, OP_MUL
|
|
|
|
|
|
|
+from .utils import is_numeric_node
|
|
|
|
|
+from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
|
|
|
from ..possibilities import Possibility as P, MESSAGES
|
|
from ..possibilities import Possibility as P, MESSAGES
|
|
|
from ..translate import _
|
|
from ..translate import _
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def is_expandable(node):
|
|
|
|
|
+ """
|
|
|
|
|
+ Check if a node is expandable. Only additions that consist of not only
|
|
|
|
|
+ numerics can be expanded.
|
|
|
|
|
+ """
|
|
|
|
|
+ return node.is_op(OP_ADD) \
|
|
|
|
|
+ and not all(map(is_numeric_node, Scope(node)))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def match_expand(node):
|
|
def match_expand(node):
|
|
|
"""
|
|
"""
|
|
|
- a(b + c) -> ab + ac
|
|
|
|
|
- (b + c)a -> ab + ac
|
|
|
|
|
|
|
+ Expand multiplication of non-numeric additions.
|
|
|
|
|
+
|
|
|
|
|
+ Examples:
|
|
|
(a + b)(c + d) -> ac + ad + bc + bd
|
|
(a + b)(c + d) -> ac + ad + bc + bd
|
|
|
|
|
+ (b + c)a -> ab + ac
|
|
|
|
|
+ a(b + c) -> ab + ac
|
|
|
"""
|
|
"""
|
|
|
assert node.is_op(OP_MUL)
|
|
assert node.is_op(OP_MUL)
|
|
|
|
|
|
|
|
p = []
|
|
p = []
|
|
|
- leaves = []
|
|
|
|
|
- additions = []
|
|
|
|
|
scope = Scope(node)
|
|
scope = Scope(node)
|
|
|
|
|
+ l = len(scope)
|
|
|
|
|
|
|
|
- for n in scope:
|
|
|
|
|
- if n.is_leaf:
|
|
|
|
|
- leaves.append(n)
|
|
|
|
|
- elif n.op == OP_ADD:
|
|
|
|
|
- # If the addition only contains numerics, do not expand
|
|
|
|
|
- if not filter(lambda n: not n.is_numeric(), Scope(n)):
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- additions.append(n)
|
|
|
|
|
|
|
+ for distance in range(1, l):
|
|
|
|
|
+ for i, left in enumerate(scope[:-distance]):
|
|
|
|
|
+ right = scope[i + distance]
|
|
|
|
|
+ l_expandable = is_expandable(left)
|
|
|
|
|
+ r_expandable = is_expandable(right)
|
|
|
|
|
|
|
|
- for l, a in product(leaves, additions):
|
|
|
|
|
- p.append(P(node, expand_single, (scope, l, a)))
|
|
|
|
|
-
|
|
|
|
|
- for a0, a1 in combinations(additions, 2):
|
|
|
|
|
- p.append(P(node, expand_double, (scope, a0, a1)))
|
|
|
|
|
|
|
+ if l_expandable and r_expandable:
|
|
|
|
|
+ p.append(P(node, expand_double, (scope, left, right)))
|
|
|
|
|
+ elif l_expandable ^ r_expandable:
|
|
|
|
|
+ p.append(P(node, expand_single, (scope, left, right)))
|
|
|
|
|
|
|
|
return p
|
|
return p
|
|
|
|
|
|
|
|
|
|
|
|
|
-def expand_single(root, args):
|
|
|
|
|
|
|
+def expand(root, args):
|
|
|
"""
|
|
"""
|
|
|
- Combine a leaf (a) multiplied with an addition of two expressions
|
|
|
|
|
- (b + c) to an addition of two multiplications.
|
|
|
|
|
-
|
|
|
|
|
- a(b + c) -> ab + ac
|
|
|
|
|
- (b + c)a -> ab + ac
|
|
|
|
|
|
|
+ (a + b)(c + d) -> ac + ad + bc + bd
|
|
|
|
|
+ (a + b)c -> ac + bc
|
|
|
|
|
+ a(b + c) -> ab + ac
|
|
|
|
|
+ etc..
|
|
|
"""
|
|
"""
|
|
|
- scope, a, bc = args
|
|
|
|
|
- b, c = bc
|
|
|
|
|
-
|
|
|
|
|
- # Replace 'a' with the new expression
|
|
|
|
|
- scope.replace(a, a * b + a * c)
|
|
|
|
|
|
|
+ scope, left, right = args
|
|
|
|
|
|
|
|
- # Remove the addition
|
|
|
|
|
- scope.remove(bc)
|
|
|
|
|
|
|
+ left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
|
|
|
|
|
+ right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
|
|
|
|
|
|
|
|
- return scope.as_nary_node()
|
|
|
|
|
|
|
+ add_scope = [l * r for l, r in product(left_scope, right_scope)]
|
|
|
|
|
+ add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
|
|
|
|
|
+ add.negated = left.negated + right.negated
|
|
|
|
|
|
|
|
|
|
+ scope.replace(left, add)
|
|
|
|
|
+ scope.remove(right)
|
|
|
|
|
|
|
|
-MESSAGES[expand_single] = _('Expand {2}({3}).')
|
|
|
|
|
|
|
+ return scope.as_nary_node()
|
|
|
|
|
|
|
|
|
|
|
|
|
def expand_double(root, args):
|
|
def expand_double(root, args):
|
|
|
- """
|
|
|
|
|
- Rewrite two multiplied additions to an addition of four multiplications.
|
|
|
|
|
|
|
+ return expand(root, args)
|
|
|
|
|
|
|
|
- (a + b)(c + d) -> ac + ad + bc + bd
|
|
|
|
|
- """
|
|
|
|
|
- scope, ab, cd = args
|
|
|
|
|
- (a, b), (c, d) = ab, cd
|
|
|
|
|
|
|
|
|
|
- # Replace 'a + b' with the new expression
|
|
|
|
|
- scope.replace(ab, a * c + a * d + b * c + b * d)
|
|
|
|
|
|
|
+MESSAGES[expand_double] = _('Expand ({2})({3}).')
|
|
|
|
|
|
|
|
- # Remove the right addition
|
|
|
|
|
- scope.remove(cd)
|
|
|
|
|
|
|
|
|
|
- return scope.as_nary_node()
|
|
|
|
|
|
|
+def expand_single(root, args):
|
|
|
|
|
+ return expand(root, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
-MESSAGES[expand_double] = _('Expand ({2})({3}).')
|
|
|
|
|
|
|
+MESSAGES[expand_single] = _('Expand ({2})({3}).')
|