Commit 9fee0791 authored by Taddeus Kroes's avatar Taddeus Kroes

Moved scope class to node.py.

parent 5383845b
......@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
return (self[0], self[1], ExpressionLeaf(1))
return (self[1], self[0], ExpressionLeaf(1))
def get_scope(self):
"""
Find all n nodes within the n-ary scope of this operator.
"""
scope = []
#op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
# TODO: what to do with OP_SUB and OP_ADD in get_scope?
for child in self:
if not child.is_leaf() and child.op == self.op:
scope += child.get_scope()
else:
scope.append(child)
return scope
def equals(self, other):
"""
Perform a non-strict equivalence check between two nodes:
......@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
return False
if self.op in (OP_ADD, OP_MUL):
s0 = self.get_scope()
s1 = set(other.get_scope())
s0 = Scope(self)
s1 = set(Scope(other))
# Scopes sould be of equal size
if len(s0) != len(s1):
......@@ -354,3 +337,72 @@ class ExpressionLeaf(Leaf, ExpressionBase):
"""
# rule: 1 * r ^ 1 -> (1, r, 1)
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = get_scope(node)
def __getitem__(self, key):
return self.nodes[key]
def __setitem__(self, key, value):
self.nodes[key] = value
def __len__(self):
return len(self.nodes)
def __iter__(self):
return iter(self.nodes)
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
return scope[0] if len(scope) == 1 \
else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
def get_scope(node):
"""
Find all n nodes within the n-ary scope of an operator node.
"""
scope = []
for child in node:
if child.is_op(node.op):
scope += get_scope(child)
else:
scope.append(child)
return scope
from rules.utils import nary_node
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = node.get_scope()
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self.nodes[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_ADD
from tests.rulestestcase import tree
class TestNode(unittest.TestCase):
class TestNode(RulesTestCase):
def setUp(self):
self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___lt__(self):
self.assertTrue(L(1) < L(2))
......@@ -168,3 +170,26 @@ class TestNode(unittest.TestCase):
m0, m1 = tree('-5 * -3,-5 * 6')
self.assertFalse(m0.equals(m1))
def test_scope___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_scope_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_scope_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_scope_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_scope_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_scope_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
import unittest
from src.scope import Scope
from tests.rulestestcase import RulesTestCase, tree
class TestScope(RulesTestCase):
def setUp(self):
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment