Просмотр исходного кода

Implemented an optimization in new validation implementation.

Taddeus Kroes 13 лет назад
Родитель
Сommit
47ed06252d
3 измененных файлов с 51 добавлено и 26 удалено
  1. 22 6
      src/validation.py
  2. 13 4
      tests/rulestestcase.py
  3. 16 16
      tests/test_validation.py

+ 22 - 6
src/validation.py

@@ -12,7 +12,7 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with TRS.  If not, see <http://www.gnu.org/licenses/>.
-from parser import Parser
+from parser import Parser, MAXIMUM_REWRITE_STEPS
 from possibilities import apply_suggestion
 from strategy import find_possibilities
 from tests.parser import ParserWrapper
@@ -26,7 +26,7 @@ VALIDATE_ERROR = 3
 
 def validate(a, b):
     """
-    Validate that a =>* b.
+    Validate that a => b.
     """
     parser = ParserWrapper(Parser)
 
@@ -34,13 +34,29 @@ def validate(a, b):
     a = parser.run([a])
     b = parser.run([b])
 
+    if a.equals(b):
+        return VALIDATION_NOPROGRESS
+
     # Evaluate a and b, counting the number of steps
-    # TODO: Optimization: if b is encountered while evaluating a, return
-    # VALIDATION_SUCCESS
+    # Optimization: if b is encountered while evaluating a, return
     parser.set_root_node(a)
-    A, a_steps = parser.rewrite_and_count_all()
+    A = a
+    a_steps = 0
+
+    for i in xrange(MAXIMUM_REWRITE_STEPS):
+        obj = parser.rewrite()
+
+        if not obj:
+            break
+
+        # If b is some reduction of a, it will be detected here
+        if obj.equals(b):
+            return VALIDATE_SUCCESS
+
+        A = obj
+        a_steps += 1
 
-    if not a:
+    if not A:
         return VALIDATE_ERROR
 
     parser.set_root_node(b)

+ 13 - 4
tests/rulestestcase.py

@@ -17,7 +17,8 @@ import doctest
 
 from src.node import ExpressionNode
 from src.parser import Parser
-from src.validation import validate
+from src.validation import validate, VALIDATE_SUCCESS, \
+        VALIDATE_FAILURE, VALIDATE_NOPROGRESS
 from tests.parser import ParserWrapper
 
 
@@ -92,6 +93,14 @@ class RulesTestCase(unittest.TestCase):
 
             raise
 
-    def assertValidate(self, exp, result):
-        self.assertTrue(validate(exp, result),
-                        'Validation failed: %s  !=>  %s')
+    def assertValidateSuccess(self, a, b):
+        self.assertEqual(validate(a, b), VALIDATE_SUCCESS,
+                         'Validation failed: %s !=> %s' % (a, b))
+
+    def assertValidateFailure(self, a, b):
+        self.assertEqual(validate(a, b), VALIDATE_FAILURE,
+                         'Validation dit not fail: %s => %s' % (a, b))
+
+    def assertValidateNoprogress(self, a, b):
+        self.assertEqual(validate(a, b), VALIDATE_NOPROGRESS, 'Validation '
+                'did detect progress or failed for %s => %s' % (a, b))

+ 16 - 16
tests/test_validation.py

@@ -12,43 +12,43 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with TRS.  If not, see <http://www.gnu.org/licenses/>.
-from unittest import TestCase
+from rulestestcase import RulesTestCase
 from src.validation import validate, VALIDATE_SUCCESS as OK, \
         VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP
 
 
-class TestValidation(TestCase):
+class TestValidation(RulesTestCase):
 
     def test_simple_success(self):
-        self.assertEqual(validate('3a + a', '4a'), OK)
+        self.assertValidateSuccess('3a + a', '4a')
 
     def test_simple_failure(self):
-        self.assertEqual(validate('3a + a', '4a + 1'), FAIL)
+        self.assertValidateFailure('3a + a', '4a + 1')
 
     def test_intermediate_success(self):
-        self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'), OK)
-        self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'), OK)
+        self.assertValidateSuccess('3a + a + b + 2b', '4a + 3b')
+        self.assertValidateSuccess('a / b / (c / d)', '(ad) / (bc)')
 
     def test_intermediate_failure(self):
-        self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'), FAIL)
+        self.assertValidateFailure('3a + a + b + 2b', '4a + 4b')
 
     def test_success(self):
-        self.assertEqual(validate('x^2 + x - 2x^2 + 3x + 1',
-                                 'x^2 + 4x - 2x^2 + 1'), OK)
+        self.assertValidateSuccess('x^2 + x - 2x^2 + 3x + 1',
+                                 'x^2 + 4x - 2x^2 + 1')
 
     def test_indefinite_integral(self):
-        self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'), OK)
+        self.assertValidateSuccess('int_2^4 x^2', '4^3/3 - 2^3/3')
 
     def test_advanced_failure(self):
-        self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'), FAIL)
+        self.assertValidateFailure('(x-1)^3+(x-1)^3', '4a+4b')
 
     def test_sphere_volume(self):
-        self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
-                                  '4 / 3 * pi * r ^ 3'), OK)
+        self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
+                                  '4 / 3 * pi * r ^ 3')
 
     def test_sphere_volume_alternative_notation(self):
-        self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
-                                 '4 * pi * r ^ 3 / 3'), OK)
+        self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
+                                 '4 * pi * r ^ 3 / 3')
 
     def test_noprogress_simple(self):
-        self.assertEqual(validate('2 + 2', '3 + 1'), NP)
+        self.assertValidateNoprogress('2 + 2', '3 + 1')