Просмотр исходного кода

Clone the scope using deepcopy, if used more than once.

Sander Mathijs van Veen 13 лет назад
Родитель
Сommit
fb74c3df4f

+ 5 - 0
src/node.py

@@ -612,12 +612,17 @@ class Scope(object):
 
 
         return nary_node(self.node.op, nodes).negate(self.node.negated)
         return nary_node(self.node.op, nodes).negate(self.node.negated)
 
 
+    def clone(self):
+        return copy.deepcopy(self)
+
 
 
 def nary_node(operator, scope):
 def nary_node(operator, scope):
     """
     """
     Create a binary expression tree for an n-ary operator. Takes the operator
     Create a binary expression tree for an n-ary operator. Takes the operator
     and a list of expression nodes as arguments.
     and a list of expression nodes as arguments.
     """
     """
+    assert scope
+
     if len(scope) == 1:
     if len(scope) == 1:
         return scope[0]
         return scope[0]
 
 

+ 2 - 1
src/rules/absolute.py

@@ -28,7 +28,8 @@ def match_factor_out_abs_term(node):
     if exp.is_op(OP_MUL):
     if exp.is_op(OP_MUL):
         scope = Scope(exp)
         scope = Scope(exp)
 
 
-        return [P(node, factor_out_abs_term, (scope, n)) for n in scope]
+        return [P(node, factor_out_abs_term, (scope.clone(), n))
+                for n in scope]
 
 
     if exp.is_op(OP_SQRT):
     if exp.is_op(OP_SQRT):
         return [P(node, factor_out_abs_sqrt)]
         return [P(node, factor_out_abs_sqrt)]

+ 3 - 2
src/rules/derivatives.py

@@ -117,7 +117,8 @@ def match_const_deriv_multiplication(node):
 
 
         for n in scope:
         for n in scope:
             if not n.contains(x):
             if not n.contains(x):
-                p.append(P(node, const_deriv_multiplication, (scope, n, x)))
+                p.append(P(node, const_deriv_multiplication,
+                    (scope.clone(), n, x)))
 
 
     return p
     return p
 
 
@@ -345,7 +346,7 @@ def match_sum_product_rule(node):
     else:
     else:
         handler = sum_rule
         handler = sum_rule
 
 
-    return [P(node, handler, (scope, f)) for f in functions]
+    return [P(node, handler, (scope.clone(), f)) for f in functions]
 
 
 
 
 def sum_rule(root, args):
 def sum_rule(root, args):

+ 8 - 6
src/rules/fractions.py

@@ -89,13 +89,14 @@ def match_add_fractions(node):
 
 
         if b == d:
         if b == d:
             # Equal denominators, add nominators to create a single fraction
             # Equal denominators, add nominators to create a single fraction
-            p.append(P(node, add_nominators, (scope, ab, cd)))
+            p.append(P(node, add_nominators, (scope.clone(), ab, cd)))
         elif all(map(is_numeric_node, (a, b, c, d))):
         elif all(map(is_numeric_node, (a, b, c, d))):
             # Denominators are both numeric, rewrite both fractions to the
             # Denominators are both numeric, rewrite both fractions to the
             # least common multiple of their denominators. Later, the
             # least common multiple of their denominators. Later, the
             # nominators will be added
             # nominators will be added
             lcm = least_common_multiple(b.value, d.value)
             lcm = least_common_multiple(b.value, d.value)
-            p.append(P(node, equalize_denominators, (scope, ab, cd, lcm)))
+            p.append(P(node, equalize_denominators,
+                (scope.clone(), ab, cd, lcm)))
 
 
             # Also, add the (non-recommended) possibility to multiply the
             # Also, add the (non-recommended) possibility to multiply the
             # denominators. Do this only if the multiplication is not equal to
             # denominators. Do this only if the multiplication is not equal to
@@ -103,7 +104,8 @@ def match_add_fractions(node):
             mult = b.value * d.value
             mult = b.value * d.value
 
 
             if mult != lcm:
             if mult != lcm:
-                p.append(P(node, equalize_denominators, (scope, ab, cd, mult)))
+                p.append(P(node, equalize_denominators,
+                    (scope.clone(), ab, cd, mult)))
 
 
     for ab, c in product(fractions, numerics):
     for ab, c in product(fractions, numerics):
         a, b = ab
         a, b = ab
@@ -111,7 +113,7 @@ def match_add_fractions(node):
         if a.is_numeric() and b.is_numeric():
         if a.is_numeric() and b.is_numeric():
             # Fraction of constants added to a constant -> create a single
             # Fraction of constants added to a constant -> create a single
             # constant fraction
             # constant fraction
-            p.append(P(node, constant_to_fraction, (scope, ab, c)))
+            p.append(P(node, constant_to_fraction, (scope.clone(), ab, c)))
 
 
     return p
     return p
 
 
@@ -189,11 +191,11 @@ def match_multiply_fractions(node):
     fractions, others = partition(lambda n: n.is_op(OP_DIV), scope)
     fractions, others = partition(lambda n: n.is_op(OP_DIV), scope)
 
 
     for ab, cd in combinations(fractions, 2):
     for ab, cd in combinations(fractions, 2):
-        p.append(P(node, multiply_fractions, (scope, ab, cd)))
+        p.append(P(node, multiply_fractions, (scope.clone(), ab, cd)))
 
 
     for ab, c in product(fractions, others):
     for ab, c in product(fractions, others):
         if evals_to_numeric(c) or not evals_to_numeric(ab):
         if evals_to_numeric(c) or not evals_to_numeric(ab):
-            p.append(P(node, multiply_with_fraction, (scope, ab, c)))
+            p.append(P(node, multiply_with_fraction, (scope.clone(), ab, c)))
 
 
     return p
     return p
 
 

+ 2 - 2
src/rules/goniometry.py

@@ -25,10 +25,10 @@ def match_add_quadrants(node):
                 continue
                 continue
 
 
             if not sin_q.negated and not cos_q.negated:
             if not sin_q.negated and not cos_q.negated:
-                p.append(P(node, add_quadrants, (scope, sin_q, cos_q)))
+                p.append(P(node, add_quadrants, (scope.clone(), sin_q, cos_q)))
             elif sin_q.negated == 1 and cos_q.negated == 1:
             elif sin_q.negated == 1 and cos_q.negated == 1:
                 p.append(P(node, factor_out_quadrant_negation,
                 p.append(P(node, factor_out_quadrant_negation,
-                    (scope, sin_q, cos_q)))
+                    (scope.clone(), sin_q, cos_q)))
 
 
     return p
     return p
 
 

+ 4 - 2
src/rules/groups.py

@@ -54,7 +54,8 @@ def match_combine_groups(node):
             c1 = c1.negate(n1.negated)
             c1 = c1.negate(n1.negated)
 
 
         if g0.equals(g1):
         if g0.equals(g1):
-            p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
+            p.append(P(node, combine_groups,
+                (scope.clone(), c0, g0, n0, c1, g1, n1)))
         elif g0.equals(g1, ignore_negation=True):
         elif g0.equals(g1, ignore_negation=True):
             # Move negations to constants
             # Move negations to constants
             c0 = c0.negate(g0.negated)
             c0 = c0.negate(g0.negated)
@@ -62,7 +63,8 @@ def match_combine_groups(node):
             g0 = negate(g0, 0, clone=True)
             g0 = negate(g0, 0, clone=True)
             g1 = negate(g1, 0, clone=True)
             g1 = negate(g1, 0, clone=True)
 
 
-            p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
+            p.append(P(node, combine_groups,
+                (scope.clone(), c0, g0, n0, c1, g1, n1)))
 
 
     return p
     return p
 
 

+ 5 - 4
src/rules/integrals.py

@@ -184,7 +184,7 @@ def match_factor_out_constant(node):
 
 
     for n in scope:
     for n in scope:
         if not n.contains(x):
         if not n.contains(x):
-            p.append(P(node, factor_out_constant, (scope, n)))
+            p.append(P(node, factor_out_constant, (scope.clone(), n)))
 
 
     return p
     return p
 
 
@@ -327,9 +327,9 @@ def match_sum_rule_integral(node):
     scope = Scope(node[0])
     scope = Scope(node[0])
 
 
     if len(scope) == 2:
     if len(scope) == 2:
-        return [P(node, sum_rule_integral, (scope, scope[0]))]
+        return [P(node, sum_rule_integral, (scope.clone(), scope[0]))]
 
 
-    return [P(node, sum_rule_integral, (scope, n)) for n in scope]
+    return [P(node, sum_rule_integral, (scope.clone(), n)) for n in scope]
 
 
 
 
 def sum_rule_integral(root, args):
 def sum_rule_integral(root, args):
@@ -360,7 +360,8 @@ def match_remove_indef_constant(node):
     x = find_variable(node[0])
     x = find_variable(node[0])
     constants = [n for n in scope if not n.contains(x)]
     constants = [n for n in scope if not n.contains(x)]
 
 
-    return [P(node, remove_indef_constant, (scope, c)) for c in constants]
+    return [P(node, remove_indef_constant, (scope.clone(), c))
+            for c in constants]
 
 
 
 
 def remove_indef_constant(root, args):
 def remove_indef_constant(root, args):

+ 3 - 2
src/rules/lineq.py

@@ -177,7 +177,8 @@ def match_multiple_equations(node):
 
 
         # Substitution rule
         # Substitution rule
         if x.is_variable() and eq1.contains(x):
         if x.is_variable() and eq1.contains(x):
-            p.append(P(node, substitute_variable, (scope, x, subs, eq1)))
+            p.append(P(node, substitute_variable,
+                (scope.clone(), x, subs, eq1)))
 
 
     return p
     return p
 
 
@@ -208,7 +209,7 @@ def match_double_case(node):
 
 
     for a, b in combinations(scope, 2):
     for a, b in combinations(scope, 2):
         if a == b:
         if a == b:
-            p.append(P(node, double_case, (scope, a, b)))
+            p.append(P(node, double_case, (scope.clone(), a, b)))
 
 
     return p
     return p
 
 

+ 10 - 7
src/rules/logarithmic.py

@@ -99,16 +99,18 @@ def match_add_logarithms(node):
 
 
         if not log_a.negated and not log_b.negated:
         if not log_a.negated and not log_b.negated:
             # log(a) + log(b)  ->  log(ab)
             # log(a) + log(b)  ->  log(ab)
-            p.append(P(node, add_logarithms, (scope, log_a, log_b)))
+            p.append(P(node, add_logarithms, (scope.clone(), log_a, log_b)))
         elif a_negated and b_negated:
         elif a_negated and b_negated:
             # -log(a) - log(b)  ->  -(log(a) + log(b))
             # -log(a) - log(b)  ->  -(log(a) + log(b))
-            p.append(P(node, expand_negations, (scope, log_a, log_b)))
+            p.append(P(node, expand_negations, (scope.clone(), log_a, log_b)))
         elif not log_a.negated and b_negated and divides(b.value, a.value):
         elif not log_a.negated and b_negated and divides(b.value, a.value):
             # log(a) - log(b)  ->  log(a / b)
             # log(a) - log(b)  ->  log(a / b)
-            p.append(P(node, subtract_logarithms, (scope, log_a, log_b)))
+            p.append(P(node, subtract_logarithms,
+                (scope.clone(), log_a, log_b)))
         elif a_negated and not log_b.negated and  divides(a.value, b.value):
         elif a_negated and not log_b.negated and  divides(a.value, b.value):
             # -log(a) + log(b)  ->  log(b / a)
             # -log(a) + log(b)  ->  log(b / a)
-            p.append(P(node, subtract_logarithms, (scope, log_b, log_a)))
+            p.append(P(node, subtract_logarithms,
+                (scope.clone(), log_b, log_a)))
 
 
     return p
     return p
 
 
@@ -187,7 +189,7 @@ def match_raised_base(node):
             # Add this possibility so that a 'raised_base' possibility is
             # Add this possibility so that a 'raised_base' possibility is
             # generated in the following iteration
             # generated in the following iteration
             p.append(P(node, factor_in_exponent_multiplicant,
             p.append(P(node, factor_in_exponent_multiplicant,
-                       (scope, other, log)))
+                       (scope.clone(), other, log)))
 
 
     return p
     return p
 
 
@@ -309,7 +311,8 @@ def match_factor_in_multiplicant(node):
     p = []
     p = []
 
 
     for constant, logarithm in product(constants, logarithms):
     for constant, logarithm in product(constants, logarithms):
-        p.append(P(node, factor_in_multiplicant, (scope, constant, logarithm)))
+        p.append(P(node, factor_in_multiplicant,
+            (scope.clone(), constant, logarithm)))
 
 
     return p
     return p
 
 
@@ -342,7 +345,7 @@ def match_expand_terms(node):
     if exp.is_op(OP_MUL):
     if exp.is_op(OP_MUL):
         scope = Scope(exp)
         scope = Scope(exp)
 
 
-        return [P(node, expand_multiplication_terms, (scope, n)) \
+        return [P(node, expand_multiplication_terms, (scope.clone(), n)) \
                 for n in ifilterfalse(is_numeric_node, scope)]
                 for n in ifilterfalse(is_numeric_node, scope)]
 
 
     if exp.is_op(OP_DIV):
     if exp.is_op(OP_DIV):

+ 1 - 1
src/rules/negation.py

@@ -19,7 +19,7 @@ def match_negated_factor(node):
 
 
     for factor in scope:
     for factor in scope:
         if factor.negated:
         if factor.negated:
-            p.append(P(node, negated_factor, (scope, factor)))
+            p.append(P(node, negated_factor, (scope.clone(), factor)))
 
 
     return p
     return p
 
 

+ 4 - 4
src/rules/numerics.py

@@ -27,12 +27,12 @@ def match_add_numerics(node):
 
 
     for n in scope:
     for n in scope:
         if n == 0:
         if n == 0:
-            p.append(P(node, remove_zero, (scope, n)))
+            p.append(P(node, remove_zero, (scope.clone(), n)))
         elif n.is_numeric():
         elif n.is_numeric():
             numerics.append(n)
             numerics.append(n)
 
 
     for c0, c1 in combinations(numerics, 2):
     for c0, c1 in combinations(numerics, 2):
-        p.append(P(node, add_numerics, (scope, c0, c1)))
+        p.append(P(node, add_numerics, (scope.clone(), c0, c1)))
 
 
     return p
     return p
 
 
@@ -186,10 +186,10 @@ def match_multiply_numerics(node):
             p.append(P(node, multiply_zero, (n,)))
             p.append(P(node, multiply_zero, (n,)))
 
 
         if n.value == 1:
         if n.value == 1:
-            p.append(P(node, multiply_one, (scope, n)))
+            p.append(P(node, multiply_one, (scope.clone(), n)))
 
 
     for c0, c1 in combinations(numerics, 2):
     for c0, c1 in combinations(numerics, 2):
-        p.append(P(node, multiply_numerics, (scope, c0, c1)))
+        p.append(P(node, multiply_numerics, (scope.clone(), c0, c1)))
 
 
     return p
     return p
 
 

+ 2 - 1
src/rules/powers.py

@@ -43,7 +43,8 @@ def match_add_exponents(node):
         # create a single power with that root
         # create a single power with that root
         if len(occurrences) > 1:
         if len(occurrences) > 1:
             for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
             for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
-                p.append(P(node, add_exponents, (scope, n0, n1, a0, e1, e2)))
+                p.append(P(node, add_exponents,
+                    (scope.clone(), n0, n1, a0, e1, e2)))
 
 
     return p
     return p
 
 

+ 2 - 2
src/rules/sort.py

@@ -123,7 +123,7 @@ def match_sort_monomial(node):
 
 
     scope = Scope(node)
     scope = Scope(node)
 
 
-    return [P(node, swap_factors, (scope, l, r))
+    return [P(node, swap_factors, (scope.clone(), l, r))
             for l, r in filter(swap_mono, iter_pairs(scope))]
             for l, r in filter(swap_mono, iter_pairs(scope))]
 
 
 
 
@@ -137,7 +137,7 @@ def match_sort_polynome(node):
 
 
     scope = Scope(node)
     scope = Scope(node)
 
 
-    return [P(node, swap_factors, (scope, l, r))
+    return [P(node, swap_factors, (scope.clone(), l, r))
             for l, r in filter(swap_poly, iter_pairs(scope))]
             for l, r in filter(swap_poly, iter_pairs(scope))]
 
 
 
 

+ 4 - 2
src/rules/sqrt.py

@@ -57,9 +57,11 @@ def match_reduce_sqrt(node):
 
 
         for n in scope:
         for n in scope:
             if is_eliminateable_sqrt(n):
             if is_eliminateable_sqrt(n):
-                p.append(P(node, extract_sqrt_mult_priority, (scope, n)))
+                p.append(P(node, extract_sqrt_mult_priority,
+                    (scope.clone(), n)))
             else:
             else:
-                p.append(P(node, extract_sqrt_multiplicant, (scope, n)))
+                p.append(P(node, extract_sqrt_multiplicant,
+                    (scope.clone(), n)))
 
 
         return p
         return p