Jelajahi Sumber

Implemented a new, lightweight way of validation as described in the TODO list.

Taddeus Kroes 13 tahun lalu
induk
melakukan
0eadc533f5
3 mengubah file dengan 55 tambahan dan 48 penghapusan
  1. 5 0
      src/parser.py
  2. 28 30
      src/validation.py
  3. 22 18
      tests/test_validation.py

+ 5 - 0
src/parser.py

@@ -382,6 +382,11 @@ class Parser(BisonParser):
 
             return self.root_node
 
+    def rewrite_and_count_all(self, check_implicit=True, verbose=False):
+        steps = self.rewrite_all(include_steps=True,
+                check_implicit=check_implicit, verbose=verbose)
+        return self.root_node, len(steps)
+
     #def hook_run(self, filename, retval):
     #    return retval
 

+ 28 - 30
src/validation.py

@@ -18,45 +18,43 @@ from strategy import find_possibilities
 from tests.parser import ParserWrapper
 
 
-def validate(exp, result):
+VALIDATE_FAILURE = 0
+VALIDATE_NOPROGRESS = 1
+VALIDATE_SUCCESS = 2
+VALIDATE_ERROR = 3
+
+
+def validate(a, b):
     """
-    Validate that exp =>* result.
+    Validate that a =>* b.
     """
     parser = ParserWrapper(Parser)
 
-    exp = parser.run([exp])
-    result = parser.run([result])
+    # Parse both expressions
+    a = parser.run([a])
+    b = parser.run([b])
 
-    # Compare the simplified expressions first, in order to avoid the
-    # computational intensive traversal of the possibilities tree.
-    parser.set_root_node(exp)
-    a = parser.rewrite_all()
+    # Evaluate a and b, counting the number of steps
+    parser.set_root_node(a)
+    A, a_steps = parser.rewrite_and_count_all()
 
     if not a:
-        return False
-
-    parser.set_root_node(result)
-    b = parser.rewrite_all()
-
-    if not a.equals(b):
-        return False
+        return VALIDATE_ERROR
 
-    # TODO: make sure cycles are avoided / eliminated using cycle detection.
-    def traverse_preorder(node, result):
-        #print 'node:', node, 'result:', result
-        if node.equals(result):
-            return True
+    parser.set_root_node(b)
+    B, b_steps = parser.rewrite_and_count_all()
 
-        for p in find_possibilities(node):
-            # Clone the root node because it will be used in multiple
-            # substitutions
-            temp = node.clone()
-            child = apply_suggestion(node, p)
-            node = temp
+    if not B:
+        return VALIDATE_ERROR
 
-            if traverse_preorder(child, result):
-                return True
+    # Evaluations must be equal
+    if not A.equals(B):
+        return VALIDATE_FAILURE
 
-        return False
+    # If evaluation of b took more staps than evaluation of a, the step from a
+    # to b was probably useless or even bad
+    if b_steps >= a_steps:
+        return VALIDATE_NOPROGRESS
 
-    return traverse_preorder(exp, result)
+    # Evaluations match and b is evaluated quicker than a => success
+    return VALIDATE_SUCCESS

+ 22 - 18
tests/test_validation.py

@@ -13,38 +13,42 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with TRS.  If not, see <http://www.gnu.org/licenses/>.
 from unittest import TestCase
-from src.validation import validate
+from src.validation import validate, VALIDATE_SUCCESS as OK, \
+        VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP
 
 
 class TestValidation(TestCase):
 
     def test_simple_success(self):
-        self.assertTrue(validate('3a + a', '4a'))
+        self.assertEqual(validate('3a + a', '4a'), OK)
 
     def test_simple_failure(self):
-        self.assertFalse(validate('3a + a', '4a + 1'))
+        self.assertEqual(validate('3a + a', '4a + 1'), FAIL)
 
     def test_intermediate_success(self):
-        self.assertTrue(validate('3a + a + b + 2b', '4a + 3b'))
-        self.assertTrue(validate('a / b / (c / d)', '(ad) / (bc)'))
+        self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'), OK)
+        self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'), OK)
 
     def test_intermediate_failure(self):
-        self.assertFalse(validate('3a + a + b + 2b', '4a + 4b'))
+        self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'), FAIL)
 
-    #def test_success(self):
-    #    self.assertTrue(validate('x^2 + x - 2x^2 + 3x + 1',
-    #                             'x^2 + 4x - 2x^2 + 1'))
+    def test_success(self):
+        self.assertEqual(validate('x^2 + x - 2x^2 + 3x + 1',
+                                 'x^2 + 4x - 2x^2 + 1'), OK)
 
-    #def test_indefinite_integral(self):
-    #    self.assertTrue(validate('int_2^4 x^2', '4^3/3 - 2^3/3'))
+    def test_indefinite_integral(self):
+        self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'), OK)
 
-    #def test_advanced_failure(self):
-    #    self.assertFalse(validate('(x-1)^3+(x-1)^3', '4a+4b'))
+    def test_advanced_failure(self):
+        self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'), FAIL)
 
     def test_sphere_volume(self):
-        self.assertTrue(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
-                                 '4 / 3 * pi * r ^ 3'))
+        self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
+                                  '4 / 3 * pi * r ^ 3'), OK)
 
-    #def test_sphere_volume_alternative(self):
-    #    self.assertTrue(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
-    #                             '4 * pi * r ^ 3 / 3'))
+    def test_sphere_volume_alternative_notation(self):
+        self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
+                                 '4 * pi * r ^ 3 / 3'), OK)
+
+    def test_noprogress_simple(self):
+        self.assertEqual(validate('2 + 2', '3 + 1'), NP)