Переглянути джерело

Moved scope class to node.py.

Taddeus Kroes 14 роки тому
батько
коміт
9fee079187
4 змінених файлів з 100 додано та 93 видалено
  1. 71 19
      src/node.py
  2. 0 34
      src/scope.py
  3. 29 4
      tests/test_node.py
  4. 0 36
      tests/test_scope.py

+ 71 - 19
src/node.py

@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
             return (self[0], self[1], ExpressionLeaf(1))
         return (self[1], self[0], ExpressionLeaf(1))
 
-    def get_scope(self):
-        """
-        Find all n nodes within the n-ary scope of this operator.
-        """
-        scope = []
-        #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
-
-        # TODO: what to do with OP_SUB and OP_ADD in get_scope?
-
-        for child in self:
-            if not child.is_leaf() and child.op == self.op:
-                scope += child.get_scope()
-            else:
-                scope.append(child)
-
-        return scope
-
     def equals(self, other):
         """
         Perform a non-strict equivalence check between two nodes:
@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
             return False
 
         if self.op in (OP_ADD, OP_MUL):
-            s0 = self.get_scope()
-            s1 = set(other.get_scope())
+            s0 = Scope(self)
+            s1 = set(Scope(other))
 
             # Scopes sould be of equal size
             if len(s0) != len(s1):
@@ -354,3 +337,72 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         """
         # rule: 1 * r ^ 1 -> (1, r, 1)
         return (ExpressionLeaf(1), self, ExpressionLeaf(1))
+
+
+class Scope(object):
+
+    def __init__(self, node):
+        self.node = node
+        self.nodes = get_scope(node)
+
+    def __getitem__(self, key):
+        return self.nodes[key]
+
+    def __setitem__(self, key, value):
+        self.nodes[key] = value
+
+    def __len__(self):
+        return len(self.nodes)
+
+    def __iter__(self):
+        return iter(self.nodes)
+
+    def remove(self, node, replacement=None):
+        if node.is_leaf():
+            node_cmp = hash(node)
+        else:
+            node_cmp = node
+
+        for i, n in enumerate(self.nodes):
+            if n.is_leaf():
+                n_cmp = hash(n)
+            else:
+                n_cmp = n
+
+            if n_cmp == node_cmp:
+                if replacement != None:
+                    self[i] = replacement
+                else:
+                    del self.nodes[i]
+
+                return
+
+        raise ValueError('Node "%s" is not in the scope of "%s".'
+                         % (node, self.node))
+
+    def as_nary_node(self):
+        return nary_node(self.node.value, self.nodes)
+
+
+def nary_node(operator, scope):
+    """
+    Create a binary expression tree for an n-ary operator. Takes the operator
+    and a list of expression nodes as arguments.
+    """
+    return scope[0] if len(scope) == 1 \
+           else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
+
+
+def get_scope(node):
+    """
+    Find all n nodes within the n-ary scope of an operator node.
+    """
+    scope = []
+
+    for child in node:
+        if child.is_op(node.op):
+            scope += get_scope(child)
+        else:
+            scope.append(child)
+
+    return scope

+ 0 - 34
src/scope.py

@@ -1,34 +0,0 @@
-from rules.utils import nary_node
-
-
-class Scope(object):
-
-    def __init__(self, node):
-        self.node = node
-        self.nodes = node.get_scope()
-
-    def remove(self, node, replacement=None):
-        if node.is_leaf():
-            node_cmp = hash(node)
-        else:
-            node_cmp = node
-
-        for i, n in enumerate(self.nodes):
-            if n.is_leaf():
-                n_cmp = hash(n)
-            else:
-                n_cmp = n
-
-            if n_cmp == node_cmp:
-                if replacement != None:
-                    self.nodes[i] = replacement
-                else:
-                    del self.nodes[i]
-
-                return
-
-        raise ValueError('Node "%s" is not in the scope of "%s".'
-                         % (node, self.node))
-
-    def as_nary_node(self):
-        return nary_node(self.node.value, self.nodes)

+ 29 - 4
tests/test_node.py

@@ -1,13 +1,15 @@
-import unittest
-
-from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
+from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_ADD
 from tests.rulestestcase import tree
 
 
-class TestNode(unittest.TestCase):
+class TestNode(RulesTestCase):
 
     def setUp(self):
         self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
+        self.n, self.f = tree('a + b + cd,f')
+        (self.a, self.b), self.cd = self.n
+        self.c, self.d = self.cd
+        self.scope = Scope(self.n)
 
     def test___lt__(self):
         self.assertTrue(L(1) < L(2))
@@ -168,3 +170,26 @@ class TestNode(unittest.TestCase):
 
         m0, m1 = tree('-5 * -3,-5 * 6')
         self.assertFalse(m0.equals(m1))
+
+    def test_scope___init__(self):
+        self.assertEqual(self.scope.node, self.n)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
+
+    def test_scope_remove_leaf(self):
+        self.scope.remove(self.b)
+        self.assertEqual(self.scope.nodes, [self.a, self.cd])
+
+    def test_scope_remove_node(self):
+        self.scope.remove(self.cd)
+        self.assertEqual(self.scope.nodes, [self.a, self.b])
+
+    def test_scope_remove_replace(self):
+        self.scope.remove(self.cd, self.f)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
+
+    def test_scope_remove_error(self):
+        with self.assertRaises(ValueError):
+            self.scope.remove(self.f)
+
+    def test_scope_as_nary_node(self):
+        self.assertEqualNodes(self.scope.as_nary_node(), self.n)

+ 0 - 36
tests/test_scope.py

@@ -1,36 +0,0 @@
-import unittest
-
-from src.scope import Scope
-from tests.rulestestcase import RulesTestCase, tree
-
-
-class TestScope(RulesTestCase):
-
-    def setUp(self):
-        self.n, self.f = tree('a + b + cd,f')
-        (self.a, self.b), self.cd = self.n
-        self.c, self.d = self.cd
-        self.scope = Scope(self.n)
-
-    def test___init__(self):
-        self.assertEqual(self.scope.node, self.n)
-        self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
-
-    def test_remove_leaf(self):
-        self.scope.remove(self.b)
-        self.assertEqual(self.scope.nodes, [self.a, self.cd])
-
-    def test_remove_node(self):
-        self.scope.remove(self.cd)
-        self.assertEqual(self.scope.nodes, [self.a, self.b])
-
-    def test_remove_replace(self):
-        self.scope.remove(self.cd, self.f)
-        self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
-
-    def test_remove_error(self):
-        with self.assertRaises(ValueError):
-            self.scope.remove(self.f)
-
-    def test_as_nary_node(self):
-        self.assertEqualNodes(self.scope.as_nary_node(), self.n)