Skip to content
Snippets Groups Projects
Commit 9fee0791 authored by Taddeus Kroes's avatar Taddeus Kroes
Browse files

Moved scope class to node.py.

parent 5383845b
No related branches found
No related tags found
No related merge requests found
......@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
return (self[0], self[1], ExpressionLeaf(1))
return (self[1], self[0], ExpressionLeaf(1))
def get_scope(self):
"""
Find all n nodes within the n-ary scope of this operator.
"""
scope = []
#op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
# TODO: what to do with OP_SUB and OP_ADD in get_scope?
for child in self:
if not child.is_leaf() and child.op == self.op:
scope += child.get_scope()
else:
scope.append(child)
return scope
def equals(self, other):
"""
Perform a non-strict equivalence check between two nodes:
......@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
return False
if self.op in (OP_ADD, OP_MUL):
s0 = self.get_scope()
s1 = set(other.get_scope())
s0 = Scope(self)
s1 = set(Scope(other))
# Scopes sould be of equal size
if len(s0) != len(s1):
......@@ -354,3 +337,72 @@ class ExpressionLeaf(Leaf, ExpressionBase):
"""
# rule: 1 * r ^ 1 -> (1, r, 1)
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = get_scope(node)
def __getitem__(self, key):
return self.nodes[key]
def __setitem__(self, key, value):
self.nodes[key] = value
def __len__(self):
return len(self.nodes)
def __iter__(self):
return iter(self.nodes)
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
return scope[0] if len(scope) == 1 \
else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
def get_scope(node):
"""
Find all n nodes within the n-ary scope of an operator node.
"""
scope = []
for child in node:
if child.is_op(node.op):
scope += get_scope(child)
else:
scope.append(child)
return scope
from rules.utils import nary_node
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = node.get_scope()
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self.nodes[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_ADD
from tests.rulestestcase import tree
class TestNode(unittest.TestCase):
class TestNode(RulesTestCase):
def setUp(self):
self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___lt__(self):
self.assertTrue(L(1) < L(2))
......@@ -168,3 +170,26 @@ class TestNode(unittest.TestCase):
m0, m1 = tree('-5 * -3,-5 * 6')
self.assertFalse(m0.equals(m1))
def test_scope___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_scope_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_scope_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_scope_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_scope_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_scope_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
import unittest
from src.scope import Scope
from tests.rulestestcase import RulesTestCase, tree
class TestScope(RulesTestCase):
def setUp(self):
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment