Browse Source

Improved 'factor' rules so that 'Expand ...'-hints make more sense.

Taddeus Kroes 14 years ago
parent
commit
5f84e8633f
1 changed files with 46 additions and 47 deletions
  1. 46 47
      src/rules/factors.py

+ 46 - 47
src/rules/factors.py

@@ -1,81 +1,80 @@
-from itertools import product, combinations
+from itertools import product
 
-from ..node import Scope, OP_ADD, OP_MUL
+from .utils import is_numeric_node
+from ..node import ExpressionNode as N, Scope, OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
 
+def is_expandable(node):
+    """
+    Check if a node is expandable. Only additions that consist of not only
+    numerics can be expanded.
+    """
+    return node.is_op(OP_ADD) \
+            and not all(map(is_numeric_node, Scope(node)))
+
+
 def match_expand(node):
     """
-    a(b + c)        ->  ab + ac
-    (b + c)a        ->  ab + ac
+    Expand multiplication of non-numeric additions.
+
+    Examples:
     (a + b)(c + d)  ->  ac + ad + bc + bd
+    (b + c)a        ->  ab + ac
+    a(b + c)        ->  ab + ac
     """
     assert node.is_op(OP_MUL)
 
     p = []
-    leaves = []
-    additions = []
     scope = Scope(node)
+    l = len(scope)
 
-    for n in scope:
-        if n.is_leaf:
-            leaves.append(n)
-        elif n.op == OP_ADD:
-            # If the addition only contains numerics, do not expand
-            if not filter(lambda n: not n.is_numeric(), Scope(n)):
-                continue
-
-            additions.append(n)
+    for distance in range(1, l):
+        for i, left in enumerate(scope[:-distance]):
+            right = scope[i + distance]
+            l_expandable = is_expandable(left)
+            r_expandable = is_expandable(right)
 
-    for l, a in product(leaves, additions):
-        p.append(P(node, expand_single, (scope, l, a)))
-
-    for a0, a1 in combinations(additions, 2):
-        p.append(P(node, expand_double, (scope, a0, a1)))
+            if l_expandable and r_expandable:
+                p.append(P(node, expand_double, (scope, left, right)))
+            elif l_expandable ^ r_expandable:
+                p.append(P(node, expand_single, (scope, left, right)))
 
     return p
 
 
-def expand_single(root, args):
+def expand(root, args):
     """
-    Combine a leaf (a) multiplied with an addition of two expressions
-    (b + c) to an addition of two multiplications.
-
-    a(b + c)  ->  ab + ac
-    (b + c)a  ->  ab + ac
+    (a + b)(c + d)  ->  ac + ad + bc + bd
+    (a + b)c        ->  ac + bc
+    a(b + c)        ->  ab + ac
+    etc..
     """
-    scope, a, bc = args
-    b, c = bc
-
-    # Replace 'a' with the new expression
-    scope.replace(a, a * b + a * c)
+    scope, left, right = args
 
-    # Remove the addition
-    scope.remove(bc)
+    left_scope = Scope(left) if left.is_op(OP_ADD) else [left]
+    right_scope = Scope(right) if right.is_op(OP_ADD) else [right]
 
-    return scope.as_nary_node()
+    add_scope = [l * r for l, r in product(left_scope, right_scope)]
+    add = Scope(N(OP_ADD, *add_scope)).as_nary_node()
+    add.negated = left.negated + right.negated
 
+    scope.replace(left, add)
+    scope.remove(right)
 
-MESSAGES[expand_single] = _('Expand {2}({3}).')
+    return scope.as_nary_node()
 
 
 def expand_double(root, args):
-    """
-    Rewrite two multiplied additions to an addition of four multiplications.
+    return expand(root, args)
 
-    (a + b)(c + d)  ->  ac + ad + bc + bd
-    """
-    scope, ab, cd = args
-    (a, b), (c, d) = ab, cd
 
-    # Replace 'a + b' with the new expression
-    scope.replace(ab, a * c + a * d + b * c + b * d)
+MESSAGES[expand_double] = _('Expand ({2})({3}).')
 
-    # Remove the right addition
-    scope.remove(cd)
 
-    return scope.as_nary_node()
+def expand_single(root, args):
+    return expand(root, args)
 
 
-MESSAGES[expand_double] = _('Expand ({2})({3}).')
+MESSAGES[expand_single] = _('Expand ({2})({3}).')