Browse Source

Added fucntion to check if a node contains another node.

Taddeus Kroes 14 years ago
parent
commit
a156a2b585
2 changed files with 26 additions and 2 deletions
  1. 17 2
      src/node.py
  2. 9 0
      tests/test_node.py

+ 17 - 2
src/node.py

@@ -50,8 +50,9 @@ OP_HINT = 20
 OP_REWRITE_ALL = 21
 OP_REWRITE = 22
 
-# Special identifierd
+# Special identifiers
 PI = 'pi'
+E = 'e'
 
 
 TYPE_MAP = {
@@ -178,7 +179,7 @@ class ExpressionBase(object):
                and (identifier == None or self.value == identifier)
 
     def is_variable(self):
-        return self.type == TYPE_IDENTIFIER and self.value != PI
+        return self.type == TYPE_IDENTIFIER and self.value not in (PI, E)
 
     def is_int(self):
         return self.type == TYPE_INTEGER
@@ -217,6 +218,20 @@ class ExpressionBase(object):
         """Negate the node n times."""
         return negate(self, self.negated + n)
 
+    def contains(self, node, include_self=True):
+        """
+        Check if a node equal to the specified one exists within this node.
+        """
+        if include_self and self == node:
+            return True
+
+        if not self.is_leaf:
+            for child in self:
+                if child.contains(node, include_self=True):
+                    return True
+
+        return False
+
 
 class ExpressionNode(Node, ExpressionBase):
     def __init__(self, *args, **kwargs):

+ 9 - 0
tests/test_node.py

@@ -209,3 +209,12 @@ class TestNode(RulesTestCase):
         n = tree('-(a + b)')
         self.assertEqualNodes(Scope(n).as_nary_node(), n)
         self.assertEqualNodes(Scope(-n).as_nary_node(), -n)
+
+    def test_contains(self):
+        a, ab, bc, ln0, ln1 = tree('a, ab, bc, ln(a) + 1, ln(b) + 1')
+
+        self.assertTrue(a.contains(a))
+        self.assertTrue(ab.contains(a))
+        self.assertFalse(bc.contains(a))
+        self.assertTrue(ln0.contains(a))
+        self.assertFalse(ln1.contains(a))