Răsfoiți Sursa

Validate using breadth-first search (this is an experimental feature).

Sander Mathijs van Veen 13 ani în urmă
părinte
comite
105b8dc41f
3 a modificat fișierele cu 137 adăugiri și 13 ștergeri
  1. 1 1
      src/parser.py
  2. 114 1
      src/validation.py
  3. 22 11
      tests/test_validation.py

+ 1 - 1
src/parser.py

@@ -29,7 +29,7 @@ import re
 
 
 # Rewriting an expression is stopped after this number of steps is passed.
-MAXIMUM_REWRITE_STEPS = 30
+MAXIMUM_REWRITE_STEPS = 20
 
 
 # Check for n-ary operator in child nodes

+ 114 - 1
src/validation.py

@@ -1,9 +1,122 @@
-from parser import Parser
+from parser import Parser, MAXIMUM_REWRITE_STEPS
 from possibilities import apply_suggestion
 from strategy import find_possibilities
 from tests.parser import ParserWrapper
 
 
+#def traverse_breadth_first(node, result, depth=0):
+#    if depth > MAXIMUM_REWRITE_STEPS:
+#        #raise RuntimeError('MAXIMUM_REWRITE_STEPS is reached.')
+#        return
+#
+#    if depth > 15:
+#        print '%3d %-30s -> %-30s' % (depth, str(node), str(result))
+#
+#    children = []
+#    possibilities = find_possibilities(node)
+#
+#    for p, possibility in enumerate(possibilities):
+#        # Clone the root node because it will be used in multiple
+#        # substitutions
+#        child = apply_suggestion(node.clone(), possibility)
+#
+#        if child.equals(result):
+#            return child
+#
+#        children.append(child)
+#
+#    # If the final expression is not found in the direct children,
+#    # start searching in the children of the children.
+#    for c, child in enumerate(children):
+#        child_step = traverse_breadth_first(child, result, depth + 1)
+#
+#        if child_step:
+#            return child_step
+
+
+from collections import deque, defaultdict
+
+
+def traverse_breadth_first(root, result, max_iterations=1e4):
+    queue = deque([root, 0])
+    i = 0
+
+    print 'root:', root, 'result:', result
+
+    counter = defaultdict(int)
+
+    while queue:
+        if i > max_iterations:
+            print 'unique:', len(counter)
+            print '\n'.join(sorted(map(str, counter.iteritems())))
+            raise RuntimeError('max_iterations is reached.')
+
+        i += 1
+
+        node = queue.popleft()
+
+        counter[str(node)] += 1
+
+        if node == 0:
+            print 'next depth: i = %d' % i
+
+            if queue:
+                queue.append(0)
+
+            continue
+
+        if node.equals(result):
+            return node
+
+        queue.extend([apply_suggestion(node.clone(), p) for p \
+                in find_possibilities(node)])
+
+
+REWRITE_INVALID = 0  # Invalid original expression (e.g. syntax error)
+REWRITE_FAILURE = 1  # One step failed to reduce to the original's reduction.
+REWRITE_SUCCESS = 2  # Steps are valid, but one or more are not in the tree.
+REWRITE_CHECKED = 3  # Steps are valid, and all are in the possibility tree.
+
+
+def validate_new(original, *steps):
+    """
+    Validate that the original expression can (in)directly be rewritten as
+    steps_0, steps_1, ..., steps_n, in that order.
+    """
+    ## TODO: make sure cycles are avoided / eliminated using cycle detection.
+    parser = ParserWrapper(Parser)
+
+    original = parser.run([original])
+    original_reduced = parser.rewrite_all()
+
+    if not original_reduced:
+        return REWRITE_INVALID
+
+    node = original
+
+    traversal = True
+
+    for s, step in enumerate(steps):
+        # Compare the simplified expressions first, in order to avoid the
+        # computationally intensive traversal of the possibilities tree.
+        step = parser.run([step])
+        step_reduced = parser.rewrite_all()
+
+        if not original_reduced.equals(step_reduced):
+            return REWRITE_FAILURE, s
+
+        if traversal:
+            if not traverse_breadth_first(node, step):
+                traversal = False
+
+            node = step
+
+    if traversal:
+        return REWRITE_CHECKED, s
+
+    return REWRITE_SUCCESS, s
+
+
 def validate(exp, result):
     """
     Validate that exp =>* result.

+ 22 - 11
tests/test_validation.py

@@ -1,24 +1,35 @@
 from unittest import TestCase
-from src.validation import validate
+from src.validation import validate, REWRITE_SUCCESS, REWRITE_FAILURE, \
+        REWRITE_INVALID, REWRITE_CHECKED
 
 
 class TestValidation(TestCase):
 
-    def test_simple_success(self):
-        self.assertTrue(validate('3a + a', '4a'))
+    # TODO: test REWRITE_INVALID (because a BisonSyntaxError is thrown now).
+    #def test_INVALID(self):
+    #    self.assertEqual(validate('3a +', '3a'), REWRITE_INVALID)
+
+    def test_simple_CHECKED(self):
+        self.assertEqual(validate('3a + a', '4a'), (REWRITE_CHECKED, 0))
 
     def test_simple_failure(self):
-        self.assertFalse(validate('3a + a', '4a + 1'))
+        self.assertEqual(validate('3a + a', '4a + 1'), (REWRITE_FAILURE, 0))
 
     def test_intermediate_success(self):
-        self.assertTrue(validate('3a + a + b + 2b', '4a + 3b'))
-        self.assertTrue(validate('a / b / (c / d)', '(ad) / (bc)'))
+        self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'),
+                         (REWRITE_CHECKED, 0))
+        self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'),
+                         (REWRITE_CHECKED, 0))
 
     def test_intermediate_failure(self):
-        self.assertFalse(validate('3a + a + b + 2b', '4a + 4b'))
+        self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'),
+                         (REWRITE_FAILURE, 0))
 
-    #def test_indefinite_integral(self):
-    #    self.assertTrue(validate('int_2^4 x^2', '4^3/3 - 2^3/3'))
+    # TODO: this test fails due to 'maximum recursion depth exceeded'.
+    def test_indefinite_integral(self):
+        self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'),
+                         (REWRITE_SUCCESS, 0))
 
-    #def test_advanced_failure(self):
-    #    self.assertFalse(validate('(x-1)^3+(x-1)^3', '4a+4b'))
+    def test_advanced_failure(self):
+        self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'),
+                         (REWRITE_FAILURE, 0))