Commit 47ed0625 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented an optimization in new validation implementation.

parent f6d30399
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from parser import Parser from parser import Parser, MAXIMUM_REWRITE_STEPS
from possibilities import apply_suggestion from possibilities import apply_suggestion
from strategy import find_possibilities from strategy import find_possibilities
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
...@@ -26,7 +26,7 @@ VALIDATE_ERROR = 3 ...@@ -26,7 +26,7 @@ VALIDATE_ERROR = 3
def validate(a, b): def validate(a, b):
""" """
Validate that a =>* b. Validate that a => b.
""" """
parser = ParserWrapper(Parser) parser = ParserWrapper(Parser)
...@@ -34,13 +34,29 @@ def validate(a, b): ...@@ -34,13 +34,29 @@ def validate(a, b):
a = parser.run([a]) a = parser.run([a])
b = parser.run([b]) b = parser.run([b])
if a.equals(b):
return VALIDATION_NOPROGRESS
# Evaluate a and b, counting the number of steps # Evaluate a and b, counting the number of steps
# TODO: Optimization: if b is encountered while evaluating a, return # Optimization: if b is encountered while evaluating a, return
# VALIDATION_SUCCESS
parser.set_root_node(a) parser.set_root_node(a)
A, a_steps = parser.rewrite_and_count_all() A = a
a_steps = 0
for i in xrange(MAXIMUM_REWRITE_STEPS):
obj = parser.rewrite()
if not obj:
break
# If b is some reduction of a, it will be detected here
if obj.equals(b):
return VALIDATE_SUCCESS
A = obj
a_steps += 1
if not a: if not A:
return VALIDATE_ERROR return VALIDATE_ERROR
parser.set_root_node(b) parser.set_root_node(b)
......
...@@ -17,7 +17,8 @@ import doctest ...@@ -17,7 +17,8 @@ import doctest
from src.node import ExpressionNode from src.node import ExpressionNode
from src.parser import Parser from src.parser import Parser
from src.validation import validate from src.validation import validate, VALIDATE_SUCCESS, \
VALIDATE_FAILURE, VALIDATE_NOPROGRESS
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
...@@ -92,6 +93,14 @@ class RulesTestCase(unittest.TestCase): ...@@ -92,6 +93,14 @@ class RulesTestCase(unittest.TestCase):
raise raise
def assertValidate(self, exp, result): def assertValidateSuccess(self, a, b):
self.assertTrue(validate(exp, result), self.assertEqual(validate(a, b), VALIDATE_SUCCESS,
'Validation failed: %s !=> %s') 'Validation failed: %s !=> %s' % (a, b))
def assertValidateFailure(self, a, b):
self.assertEqual(validate(a, b), VALIDATE_FAILURE,
'Validation dit not fail: %s => %s' % (a, b))
def assertValidateNoprogress(self, a, b):
self.assertEqual(validate(a, b), VALIDATE_NOPROGRESS, 'Validation '
'did detect progress or failed for %s => %s' % (a, b))
...@@ -12,43 +12,43 @@ ...@@ -12,43 +12,43 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from unittest import TestCase from rulestestcase import RulesTestCase
from src.validation import validate, VALIDATE_SUCCESS as OK, \ from src.validation import validate, VALIDATE_SUCCESS as OK, \
VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP
class TestValidation(TestCase): class TestValidation(RulesTestCase):
def test_simple_success(self): def test_simple_success(self):
self.assertEqual(validate('3a + a', '4a'), OK) self.assertValidateSuccess('3a + a', '4a')
def test_simple_failure(self): def test_simple_failure(self):
self.assertEqual(validate('3a + a', '4a + 1'), FAIL) self.assertValidateFailure('3a + a', '4a + 1')
def test_intermediate_success(self): def test_intermediate_success(self):
self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'), OK) self.assertValidateSuccess('3a + a + b + 2b', '4a + 3b')
self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'), OK) self.assertValidateSuccess('a / b / (c / d)', '(ad) / (bc)')
def test_intermediate_failure(self): def test_intermediate_failure(self):
self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'), FAIL) self.assertValidateFailure('3a + a + b + 2b', '4a + 4b')
def test_success(self): def test_success(self):
self.assertEqual(validate('x^2 + x - 2x^2 + 3x + 1', self.assertValidateSuccess('x^2 + x - 2x^2 + 3x + 1',
'x^2 + 4x - 2x^2 + 1'), OK) 'x^2 + 4x - 2x^2 + 1')
def test_indefinite_integral(self): def test_indefinite_integral(self):
self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'), OK) self.assertValidateSuccess('int_2^4 x^2', '4^3/3 - 2^3/3')
def test_advanced_failure(self): def test_advanced_failure(self):
self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'), FAIL) self.assertValidateFailure('(x-1)^3+(x-1)^3', '4a+4b')
def test_sphere_volume(self): def test_sphere_volume(self):
self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
'4 / 3 * pi * r ^ 3'), OK) '4 / 3 * pi * r ^ 3')
def test_sphere_volume_alternative_notation(self): def test_sphere_volume_alternative_notation(self):
self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
'4 * pi * r ^ 3 / 3'), OK) '4 * pi * r ^ 3 / 3')
def test_noprogress_simple(self): def test_noprogress_simple(self):
self.assertEqual(validate('2 + 2', '3 + 1'), NP) self.assertValidateNoprogress('2 + 2', '3 + 1')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment