Sfoglia il codice sorgente

Added Scope class for easy scope-level operations.

Taddeus Kroes 14 anni fa
parent
commit
5383845b30
2 ha cambiato i file con 70 aggiunte e 0 eliminazioni
  1. 34 0
      src/scope.py
  2. 36 0
      tests/test_scope.py

+ 34 - 0
src/scope.py

@@ -0,0 +1,34 @@
+from rules.utils import nary_node
+
+
+class Scope(object):
+
+    def __init__(self, node):
+        self.node = node
+        self.nodes = node.get_scope()
+
+    def remove(self, node, replacement=None):
+        if node.is_leaf():
+            node_cmp = hash(node)
+        else:
+            node_cmp = node
+
+        for i, n in enumerate(self.nodes):
+            if n.is_leaf():
+                n_cmp = hash(n)
+            else:
+                n_cmp = n
+
+            if n_cmp == node_cmp:
+                if replacement != None:
+                    self.nodes[i] = replacement
+                else:
+                    del self.nodes[i]
+
+                return
+
+        raise ValueError('Node "%s" is not in the scope of "%s".'
+                         % (node, self.node))
+
+    def as_nary_node(self):
+        return nary_node(self.node.value, self.nodes)

+ 36 - 0
tests/test_scope.py

@@ -0,0 +1,36 @@
+import unittest
+
+from src.scope import Scope
+from tests.rulestestcase import RulesTestCase, tree
+
+
+class TestScope(RulesTestCase):
+
+    def setUp(self):
+        self.n, self.f = tree('a + b + cd,f')
+        (self.a, self.b), self.cd = self.n
+        self.c, self.d = self.cd
+        self.scope = Scope(self.n)
+
+    def test___init__(self):
+        self.assertEqual(self.scope.node, self.n)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
+
+    def test_remove_leaf(self):
+        self.scope.remove(self.b)
+        self.assertEqual(self.scope.nodes, [self.a, self.cd])
+
+    def test_remove_node(self):
+        self.scope.remove(self.cd)
+        self.assertEqual(self.scope.nodes, [self.a, self.b])
+
+    def test_remove_replace(self):
+        self.scope.remove(self.cd, self.f)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
+
+    def test_remove_error(self):
+        with self.assertRaises(ValueError):
+            self.scope.remove(self.f)
+
+    def test_as_nary_node(self):
+        self.assertEqualNodes(self.scope.as_nary_node(), self.n)