Browse Source

Changed negate() function to negate by reference by default (as a small optimization).

Taddeus Kroes 13 years ago
parent
commit
08c982bba4
3 changed files with 15 additions and 10 deletions
  1. 12 7
      src/node.py
  2. 2 2
      src/rules/groups.py
  3. 1 1
      src/rules/powers.py

+ 12 - 7
src/node.py

@@ -251,13 +251,13 @@ class ExpressionBase(object):
 
     def negate(self, n=1):
         """Negate the node n times."""
-        return negate(self, self.negated + n)
+        return negate(self, self.negated + n, clone=True)
 
     def contains(self, node, include_self=True):
         """
         Check if a node equal to the specified one exists within this node.
         """
-        if include_self and negate(self, 0) == node:
+        if include_self and self.equals(node, ignore_negation=True):
             return True
 
         if not self.is_leaf:
@@ -620,14 +620,19 @@ def get_scope(node):
     return scope
 
 
-def negate(node, n=1):
-    """Negate the given node n times."""
+def negate(node, n=1, clone=False):
+    """
+    Negate the given node n times. If clone is set to true, return a new node
+    so that the original node is not altered.
+    """
     assert n >= 0
 
-    new_node = node.clone()
-    new_node.negated = n
+    if clone:
+        node = node.clone()
+
+    node.negated = n
 
-    return new_node
+    return node
 
 
 def infinity():

+ 2 - 2
src/rules/groups.py

@@ -59,8 +59,8 @@ def match_combine_groups(node):
             # Move negations to constants
             c0 = c0.negate(g0.negated)
             c1 = c1.negate(g1.negated)
-            g0 = negate(g0, 0)
-            g1 = negate(g1, 0)
+            g0 = negate(g0, 0, clone=True)
+            g1 = negate(g1, 0, clone=True)
 
             p.append(P(node, combine_groups, (scope, c0, g0, n0, c1, g1, n1)))
 

+ 1 - 1
src/rules/powers.py

@@ -24,7 +24,7 @@ def match_add_exponents(node):
         # Order powers by their roots, e.g. a^p and a^q are put in the same
         # list because of the mutual 'a'
         if n.is_identifier():
-            s = negate(n, 0)
+            s = negate(n, 0, clone=True)
             exponent = L(1)
         elif n.is_op(OP_POW):
             s, exponent = n