Clone the scope using deepcopy, if used more than once.

parent a0c79c18
......@@ -612,12 +612,17 @@ class Scope(object):
return nary_node(self.node.op, nodes).negate(self.node.negated)
def clone(self):
return copy.deepcopy(self)
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
assert scope
if len(scope) == 1:
return scope[0]
......
......@@ -28,7 +28,8 @@ def match_factor_out_abs_term(node):
if exp.is_op(OP_MUL):
scope = Scope(exp)
return [P(node, factor_out_abs_term, (scope, n)) for n in scope]
return [P(node, factor_out_abs_term, (scope.clone(), n))
for n in scope]
if exp.is_op(OP_SQRT):
return [P(node, factor_out_abs_sqrt)]
......
......@@ -117,7 +117,8 @@ def match_const_deriv_multiplication(node):
for n in scope:
if not n.contains(x):
p.append(P(node, const_deriv_multiplication, (scope, n, x)))
p.append(P(node, const_deriv_multiplication,
(scope.clone(), n, x)))
return p
......@@ -345,7 +346,7 @@ def match_sum_product_rule(node):
else:
handler = sum_rule
return [P(node, handler, (scope, f)) for f in functions]
return [P(node, handler, (scope.clone(), f)) for f in functions]
def sum_rule(root, args):
......
......@@ -89,13 +89,14 @@ def match_add_fractions(node):
if b == d:
# Equal denominators, add nominators to create a single fraction
p.append(P(node, add_nominators, (scope, ab, cd)))
p.append(P(node, add_nominators, (scope.clone(), ab, cd)))
elif all(map(is_numeric_node, (a, b, c, d))):
# Denominators are both numeric, rewrite both fractions to the
# least common multiple of their denominators. Later, the
# nominators will be added
lcm = least_common_multiple(b.value, d.value)
p.append(P(node, equalize_denominators, (scope, ab, cd, lcm)))
p.append(P(node, equalize_denominators,
(scope.clone(), ab, cd, lcm)))
# Also, add the (non-recommended) possibility to multiply the
# denominators. Do this only if the multiplication is not equal to
......@@ -103,7 +104,8 @@ def match_add_fractions(node):
mult = b.value * d.value
if mult != lcm:
p.append(P(node, equalize_denominators, (scope, ab, cd, mult)))
p.append(P(node, equalize_denominators,
(scope.clone(), ab, cd, mult)))
for ab, c in product(fractions, numerics):
a, b = ab
......@@ -111,7 +113,7 @@ def match_add_fractions(node):
if a.is_numeric() and b.is_numeric():
# Fraction of constants added to a constant -> create a single
# constant fraction
p.append(P(node, constant_to_fraction, (scope, ab, c)))
p.append(P(node, constant_to_fraction, (scope.clone(), ab, c)))
return p
......@@ -189,11 +191,11 @@ def match_multiply_fractions(node):
fractions, others = partition(lambda n: n.is_op(OP_DIV), scope)
for ab, cd in combinations(fractions, 2):
p.append(P(node, multiply_fractions, (scope, ab, cd)))
p.append(P(node, multiply_fractions, (scope.clone(), ab, cd)))
for ab, c in product(fractions, others):
if evals_to_numeric(c) or not evals_to_numeric(ab):
p.append(P(node, multiply_with_fraction, (scope, ab, c)))
p.append(P(node, multiply_with_fraction, (scope.clone(), ab, c)))
return p
......
......@@ -25,10 +25,10 @@ def match_add_quadrants(node):
continue
if not sin_q.negated and not cos_q.negated:
p.append(P(node, add_quadrants, (scope, sin_q, cos_q)))
p.append(P(node, add_quadrants, (scope.clone(), sin_q, cos_q)))
elif sin_q.negated == 1 and cos_q.negated == 1:
p.append(P(node, factor_out_quadrant_negation,
(scope, sin_q, cos_q)))
(scope.clone(), sin_q, cos_q)))
return p
......
......@@ -54,7 +54,8 @@ def match_combine_groups(node):
c1 = c1.negate(n1.negated)
if g0.equals(g1):
p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
p.append(P(node, combine_groups,
(scope.clone(), c0, g0, n0, c1, g1, n1)))
elif g0.equals(g1, ignore_negation=True):
# Move negations to constants
c0 = c0.negate(g0.negated)
......@@ -62,7 +63,8 @@ def match_combine_groups(node):
g0 = negate(g0, 0, clone=True)
g1 = negate(g1, 0, clone=True)
p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
p.append(P(node, combine_groups,
(scope.clone(), c0, g0, n0, c1, g1, n1)))
return p
......
......@@ -184,7 +184,7 @@ def match_factor_out_constant(node):
for n in scope:
if not n.contains(x):
p.append(P(node, factor_out_constant, (scope, n)))
p.append(P(node, factor_out_constant, (scope.clone(), n)))
return p
......@@ -327,9 +327,9 @@ def match_sum_rule_integral(node):
scope = Scope(node[0])
if len(scope) == 2:
return [P(node, sum_rule_integral, (scope, scope[0]))]
return [P(node, sum_rule_integral, (scope.clone(), scope[0]))]
return [P(node, sum_rule_integral, (scope, n)) for n in scope]
return [P(node, sum_rule_integral, (scope.clone(), n)) for n in scope]
def sum_rule_integral(root, args):
......@@ -360,7 +360,8 @@ def match_remove_indef_constant(node):
x = find_variable(node[0])
constants = [n for n in scope if not n.contains(x)]
return [P(node, remove_indef_constant, (scope, c)) for c in constants]
return [P(node, remove_indef_constant, (scope.clone(), c))
for c in constants]
def remove_indef_constant(root, args):
......
......@@ -177,7 +177,8 @@ def match_multiple_equations(node):
# Substitution rule
if x.is_variable() and eq1.contains(x):
p.append(P(node, substitute_variable, (scope, x, subs, eq1)))
p.append(P(node, substitute_variable,
(scope.clone(), x, subs, eq1)))
return p
......@@ -208,7 +209,7 @@ def match_double_case(node):
for a, b in combinations(scope, 2):
if a == b:
p.append(P(node, double_case, (scope, a, b)))
p.append(P(node, double_case, (scope.clone(), a, b)))
return p
......
......@@ -99,16 +99,18 @@ def match_add_logarithms(node):
if not log_a.negated and not log_b.negated:
# log(a) + log(b) -> log(ab)
p.append(P(node, add_logarithms, (scope, log_a, log_b)))
p.append(P(node, add_logarithms, (scope.clone(), log_a, log_b)))
elif a_negated and b_negated:
# -log(a) - log(b) -> -(log(a) + log(b))
p.append(P(node, expand_negations, (scope, log_a, log_b)))
p.append(P(node, expand_negations, (scope.clone(), log_a, log_b)))
elif not log_a.negated and b_negated and divides(b.value, a.value):
# log(a) - log(b) -> log(a / b)
p.append(P(node, subtract_logarithms, (scope, log_a, log_b)))
p.append(P(node, subtract_logarithms,
(scope.clone(), log_a, log_b)))
elif a_negated and not log_b.negated and divides(a.value, b.value):
# -log(a) + log(b) -> log(b / a)
p.append(P(node, subtract_logarithms, (scope, log_b, log_a)))
p.append(P(node, subtract_logarithms,
(scope.clone(), log_b, log_a)))
return p
......@@ -187,7 +189,7 @@ def match_raised_base(node):
# Add this possibility so that a 'raised_base' possibility is
# generated in the following iteration
p.append(P(node, factor_in_exponent_multiplicant,
(scope, other, log)))
(scope.clone(), other, log)))
return p
......@@ -309,7 +311,8 @@ def match_factor_in_multiplicant(node):
p = []
for constant, logarithm in product(constants, logarithms):
p.append(P(node, factor_in_multiplicant, (scope, constant, logarithm)))
p.append(P(node, factor_in_multiplicant,
(scope.clone(), constant, logarithm)))
return p
......@@ -342,7 +345,7 @@ def match_expand_terms(node):
if exp.is_op(OP_MUL):
scope = Scope(exp)
return [P(node, expand_multiplication_terms, (scope, n)) \
return [P(node, expand_multiplication_terms, (scope.clone(), n)) \
for n in ifilterfalse(is_numeric_node, scope)]
if exp.is_op(OP_DIV):
......
......@@ -19,7 +19,7 @@ def match_negated_factor(node):
for factor in scope:
if factor.negated:
p.append(P(node, negated_factor, (scope, factor)))
p.append(P(node, negated_factor, (scope.clone(), factor)))
return p
......
......@@ -27,12 +27,12 @@ def match_add_numerics(node):
for n in scope:
if n == 0:
p.append(P(node, remove_zero, (scope, n)))
p.append(P(node, remove_zero, (scope.clone(), n)))
elif n.is_numeric():
numerics.append(n)
for c0, c1 in combinations(numerics, 2):
p.append(P(node, add_numerics, (scope, c0, c1)))
p.append(P(node, add_numerics, (scope.clone(), c0, c1)))
return p
......@@ -186,10 +186,10 @@ def match_multiply_numerics(node):
p.append(P(node, multiply_zero, (n,)))
if n.value == 1:
p.append(P(node, multiply_one, (scope, n)))
p.append(P(node, multiply_one, (scope.clone(), n)))
for c0, c1 in combinations(numerics, 2):
p.append(P(node, multiply_numerics, (scope, c0, c1)))
p.append(P(node, multiply_numerics, (scope.clone(), c0, c1)))
return p
......
......@@ -43,7 +43,8 @@ def match_add_exponents(node):
# create a single power with that root
if len(occurrences) > 1:
for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
p.append(P(node, add_exponents, (scope, n0, n1, a0, e1, e2)))
p.append(P(node, add_exponents,
(scope.clone(), n0, n1, a0, e1, e2)))
return p
......
......@@ -123,7 +123,7 @@ def match_sort_monomial(node):
scope = Scope(node)
return [P(node, swap_factors, (scope, l, r))
return [P(node, swap_factors, (scope.clone(), l, r))
for l, r in filter(swap_mono, iter_pairs(scope))]
......@@ -137,7 +137,7 @@ def match_sort_polynome(node):
scope = Scope(node)
return [P(node, swap_factors, (scope, l, r))
return [P(node, swap_factors, (scope.clone(), l, r))
for l, r in filter(swap_poly, iter_pairs(scope))]
......
......@@ -57,9 +57,11 @@ def match_reduce_sqrt(node):
for n in scope:
if is_eliminateable_sqrt(n):
p.append(P(node, extract_sqrt_mult_priority, (scope, n)))
p.append(P(node, extract_sqrt_mult_priority,
(scope.clone(), n)))
else:
p.append(P(node, extract_sqrt_multiplicant, (scope, n)))
p.append(P(node, extract_sqrt_multiplicant,
(scope.clone(), n)))
return p
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment