Prechádzať zdrojové kódy

Replaced get_scope_except function by better use of standard libraries.

Taddeus Kroes 14 rokov pred
rodič
commit
206246be10
2 zmenil súbory, kde vykonal 13 pridanie a 78 odobranie
  1. 3 68
      src/node.py
  2. 10 10
      src/rules/poly.py

+ 3 - 68
src/node.py

@@ -114,19 +114,18 @@ class ExpressionNode(Node, ExpressionBase):
                         return (power[0].value, power[1].value, coeff.value)
 
     def get_scope(self):
+        """"""
         scope = []
+        #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
 
         for child in self:
-            if not isinstance(child, Leaf) and child.op == self.op:
+            if not child.is_leaf() and child.op & self.op:
                 scope += child.get_scope()
             else:
                 scope.append(child)
 
         return scope
 
-    def get_scope_except(self, *args):
-        return list(set(self.get_scope()) - set(args))
-
 
 class ExpressionLeaf(Leaf, ExpressionBase):
     def __init__(self, *args, **kwargs):
@@ -146,67 +145,3 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         self.parent.nodes[pos] = node
         node.parent = self.parent
         self.parent = None
-
-
-if __name__ == '__main__':  # pragma: nocover
-    l0 = ExpressionLeaf(3)
-    l1 = ExpressionLeaf(4)
-    l2 = ExpressionLeaf(5)
-    l3 = ExpressionLeaf(7)
-
-    n0 = ExpressionNode('+', l0, l1)
-    n1 = ExpressionNode('+', l2, l3)
-    n2 = ExpressionNode('*', n0, n1)
-
-    print n2
-
-    N = ExpressionNode
-
-    def rewrite_multiply(node):
-        a, b = node[0]
-        c, d = node[1]
-
-        ac = N('*', a, c)
-        ad = N('*', a, d)
-        bc = N('*', b, c)
-        bd = N('*', b, d)
-
-        res = N('+', N('+', N('+', ac, ad), bc), bd)
-
-        return res
-
-    possibilities = [
-            (n0, lambda (x, y): ExpressionLeaf(x.value + y.value)),
-            (n1, lambda (x, y): ExpressionLeaf(x.value + y.value)),
-            (n2, rewrite_multiply),
-            ]
-
-    print '\n--- after rule 2 ---\n'
-
-    n_, method = possibilities[2]
-    new = method(n_)
-
-    print new
-
-    print '\n--- original graph ---\n'
-
-    print n2
-
-    print '\n--- apply rule 0 ---\n'
-
-    n_, method = possibilities[0]
-    new = method(n_)
-    n_.replace(new)
-
-    print n2
-
-    # Revert rule 0
-    new.replace(n_)
-
-    print '\n--- apply rule 1 ---\n'
-
-    n_, method = possibilities[1]
-    new = method(n_)
-    n_.replace(new)
-
-    print n2

+ 10 - 10
src/rules/poly.py

@@ -39,21 +39,21 @@ def match_expand(node):
 
 def expand_single(root, args):
     """
-    Combine a leaf (left) multiplied with an addition of two expressions
-    (right) to an addition of two multiplications.
+    Combine a leaf (a) multiplied with an addition of two expressions
+    (b + c) to an addition of two multiplications.
 
     >>> a * (b + c) -> a * b + a * c
     """
-    left, right = args
-    scope = root.get_scope_except(right)
+    a, bc = args
+    b, c = bc
+    scope = root.get_scope()
 
-    replacement = Node('+', Node('*', left, right[0]), \
-                            Node('*', left, right[1]))
+    # Replace 'a' with the new expression
+    scope[scope.index(a)] = Node('+', Node('*', a, b), \
+                                      Node('*', a, c))
 
-    for i, n in enumerate(scope):
-        if n == left:
-            scope[i] = replacement
-            break
+    # Remove the old addition
+    scope.remove(bc)
 
     return nary_node('*', scope)