Переглянути джерело

Implemented subtree substitution.

Once a possibility is applied, the new subtree will be substituted into the
parent node of the subtree. This requires a subtree_map to avoid traversing the
expression tree. The subtree can also be the root node of the expression tree.
In that case, there is no parent node found for the subtree. Otherwise the old
subtree will be substituted with the new subtree node in the parent node.
Also updated the test cases accordingly.
Sander Mathijs van Veen 14 роки тому
батько
коміт
d9b28ea2b5
5 змінених файлів з 118 додано та 22 видалено
  1. 7 0
      src/node.py
  2. 44 8
      src/parser.py
  3. 29 2
      src/possibilities.py
  4. 13 12
      tests/test_leiden_oefenopgave.py
  5. 25 0
      tests/test_rewrite.py

+ 7 - 0
src/node.py

@@ -1,6 +1,7 @@
 # vim: set fileencoding=utf-8 :
 import os.path
 import sys
+import copy
 
 sys.path.insert(0, os.path.realpath('external'))
 
@@ -59,6 +60,9 @@ def to_expression(obj):
 
 
 class ExpressionBase(object):
+    def clone(self):
+        return copy.deepcopy(self)
+
     def __lt__(self, other):
         """
         Comparison between this expression{node,leaf} and another
@@ -161,6 +165,9 @@ class ExpressionNode(Node, ExpressionBase):
 
         return False
 
+    def substitute(self, old_child, new_child):
+        self.nodes[self.nodes.index(old_child)] = new_child
+
     def graph(self):  # pragma: nocover
         return generate_graph(self)
 

+ 44 - 8
src/parser.py

@@ -79,6 +79,9 @@ class Parser(BisonParser):
         self.read_buffer = ''
         self.read_queue = Queue.Queue()
 
+        self.subtree_map = {}
+        self.root_node = None
+
     # Override default read method with a version that prompts for input.
     def read(self, nbytes):
         if self.file == sys.stdin and self.file.closed:
@@ -181,11 +184,43 @@ class Parser(BisonParser):
                 or retval.type != TYPE_OPERATOR or retval.op not in RULES:
             return retval
 
+        # Update the subtree map to let the subtree point to its parent node.
+        parent_nodes = self.subtree_map.keys()
+
+        for child in retval:
+            if child in parent_nodes:
+                self.subtree_map[child] = retval
+
         for handler in RULES[retval.op]:
-            self.possibilities.extend(handler(retval))
+            possibilities = handler(retval)
+
+            # Record the subtree root node in order to avoid tree traversal.
+            # At this moment, the node is the root node since the expression is
+            # parser using the left-innermost parsing strategy.
+            for p in possibilities:
+                self.subtree_map[p.root] = None
+
+            self.possibilities.extend(possibilities)
 
         return retval
 
+    def display_hint(self):
+        print pick_suggestion(self.last_possibilities)
+
+    def display_possibilities(self):
+        print '\n'.join(map(str, self.last_possibilities))
+
+    def rewrite(self):
+        suggestion = pick_suggestion(self.last_possibilities)
+
+        if not suggestion:
+            return self.root_node
+
+        expression = apply_suggestion(self.root_node, self.subtree_map,
+                                    suggestion)
+        self.read_queue.put_nowait(str(expression))
+        return expression
+
     #def hook_run(self, filename, retval):
     #    return retval
 
@@ -226,22 +261,23 @@ class Parser(BisonParser):
              | REWRITE NEWLINE
              | RAISE NEWLINE
         """
-        if option in [1, 2]:  # rule: EXP NEWLINE | DEBUG NEWLINE
+        if option == 1:  # rule: EXP NEWLINE
+            self.root_node = values[0]
+            return values[0]
+
+        if option == 2:  # rule: DEBUG NEWLINE
             return values[0]
 
         if option == 3:  # rule: HINT NEWLINE
-            print pick_suggestion(self.last_possibilities)
+            self.display_hint()
             return
 
         if option == 4:  # rule: POSSIBILITIES NEWLINE
-            print '\n'.join(map(str, self.last_possibilities))
+            self.display_possibilities()
             return
 
         if option == 5:  # rule: REWRITE NEWLINE
-            suggestion = pick_suggestion(self.last_possibilities)
-            expression = apply_suggestion(suggestion)
-            self.read_queue.put_nowait(str(expression))
-            return expression
+            return self.rewrite()
 
         if option == 6:
             raise RuntimeError('on_line: exception raised')

+ 29 - 2
src/possibilities.py

@@ -53,10 +53,37 @@ def filter_duplicates(possibilities):
 
 
 def pick_suggestion(possibilities):
+    if not possibilities:
+        return
+
     # TODO: pick the best suggestion.
     suggestion = 0
     return possibilities[suggestion]
 
 
-def apply_suggestion(suggestion):
-    return suggestion.handler(suggestion.root, suggestion.args)
+def apply_suggestion(root, subtree_map, suggestion):
+    # clone the root node before modifying. After deep copying the root node,
+    # the subtree_map cannot be used since the hash() of each node in the deep
+    # copied root node has changed.
+    #root_clone = root.clone()
+
+    subtree = suggestion.handler(suggestion.root, suggestion.args)
+
+    if suggestion.root in subtree_map:
+        parent_node = subtree_map[suggestion.root]
+    else:
+        parent_node = None
+
+    # There is either a parent node or the subtree is the root node.
+    # FIXME: FAIL: test_diagnostic_test_application in tests/test_b1_ch08.py
+    #try:
+    #    assert bool(parent_node) != (subtree == root)
+    #except:
+    #    print 'parent_node: %s' % (str(parent_node))
+    #    print 'subtree: %s == %s' % (str(subtree), str(root))
+    #    raise
+
+    if parent_node:
+        parent_node.substitute(suggestion.root, subtree)
+        return root
+    return subtree

+ 13 - 12
tests/test_leiden_oefenopgave.py

@@ -4,8 +4,8 @@ from src.parser import Parser
 from tests.parser import ParserWrapper
 
 
-def reduce(exp, **kwargs):
-    return ParserWrapper(Parser, **kwargs).run([exp]).reduce()
+def rewrite(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
 
 
 class TestLeidenOefenopgave(TestCase):
@@ -19,7 +19,7 @@ class TestLeidenOefenopgave(TestCase):
                 ('-2(6x-4)^2*x',         '-72 * x^3 + 96 * x ^ 2 + 32 * x'),
                 ('(4x + 5) * -(5 - 4x)', '16x^2 - 25'),
                 ]:
-            self.assertEqual(str(reduce(exp)), solution)
+            self.assertEqual(str(rewrite(exp)), solution)
 
     def test_2(self):
         pass
@@ -28,14 +28,15 @@ class TestLeidenOefenopgave(TestCase):
         pass
 
     def test_4(self):
-        return
         for exp, solution in [
-                ('2/15 + 1/4',      '23/60'),
-                ('2/7 - 4/11',      '-6/77'),
-                ('(7/3) * (3/5)',   '7/5'),
-                ('(3/4) / (5/6)',   '9/10'),
-                ('1/4 * 1/x',       '1/(4x)'),
-                ('(3/x^2) / (x/7)', '21/x^3'),
-                ('1/x + 2/(x+1)',   '(3x + 1) / (x * (x + 1))'),
+                ('2/15 + 1/4',      '8 / 60 + 15 / 60'),
+                ('8/60 + 15/60',    '(8 + 15) / 60'),
+                ('(8 + 15) / 60',   '23 / 60'),
+                # FIXME: ('2/7 - 4/11',      '-6 / 77'),
+                # FIXME: ('(7/3) * (3/5)',   '7 / 5'),
+                # FIXME: ('(3/4) / (5/6)',   '9 / 10'),
+                # FIXME: ('1/4 * 1/x',       '1 / (4x)'),
+                # FIXME: ('(3/x^2) / (x/7)', '21 / x^3'),
+                # FIXME: ('1/x + 2/(x+1)',   '(3x + 1) / (x * (x + 1))'),
                 ]:
-            self.assertEqual(str(reduce(exp)), solution)
+            self.assertEqual(str(rewrite(exp)), solution)

+ 25 - 0
tests/test_rewrite.py

@@ -0,0 +1,25 @@
+from unittest import TestCase
+
+from src.parser import Parser
+from tests.parser import ParserWrapper
+
+
+def rewrite(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
+
+
+class TestRewrite(TestCase):
+
+    def assertRewrite(self, rewrite_chain):
+        try:
+            for i, exp in enumerate(rewrite_chain[:-1]):
+                self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
+        except AssertionError:
+            print exp, '->', rewrite_chain[i+1]
+            raise
+
+    def test_addition_rewrite(self):
+        self.assertRewrite(['2 + 3 + 4', '5 + 4', '9'])
+
+    def test_addition_identifiers_rewrite(self):
+        self.assertRewrite(['2 + 3a + 4', '6 + 3a'])