Commit a156a2b5 authored by Taddeus Kroes's avatar Taddeus Kroes

Added fucntion to check if a node contains another node.

parent dd613cc3
......@@ -50,8 +50,9 @@ OP_HINT = 20
OP_REWRITE_ALL = 21
OP_REWRITE = 22
# Special identifierd
# Special identifiers
PI = 'pi'
E = 'e'
TYPE_MAP = {
......@@ -178,7 +179,7 @@ class ExpressionBase(object):
and (identifier == None or self.value == identifier)
def is_variable(self):
return self.type == TYPE_IDENTIFIER and self.value != PI
return self.type == TYPE_IDENTIFIER and self.value not in (PI, E)
def is_int(self):
return self.type == TYPE_INTEGER
......@@ -217,6 +218,20 @@ class ExpressionBase(object):
"""Negate the node n times."""
return negate(self, self.negated + n)
def contains(self, node, include_self=True):
"""
Check if a node equal to the specified one exists within this node.
"""
if include_self and self == node:
return True
if not self.is_leaf:
for child in self:
if child.contains(node, include_self=True):
return True
return False
class ExpressionNode(Node, ExpressionBase):
def __init__(self, *args, **kwargs):
......
......@@ -209,3 +209,12 @@ class TestNode(RulesTestCase):
n = tree('-(a + b)')
self.assertEqualNodes(Scope(n).as_nary_node(), n)
self.assertEqualNodes(Scope(-n).as_nary_node(), -n)
def test_contains(self):
a, ab, bc, ln0, ln1 = tree('a, ab, bc, ln(a) + 1, ln(b) + 1')
self.assertTrue(a.contains(a))
self.assertTrue(ab.contains(a))
self.assertFalse(bc.contains(a))
self.assertTrue(ln0.contains(a))
self.assertFalse(ln1.contains(a))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment