Skip to content
Snippets Groups Projects
Commit 47ed0625 authored by Taddeüs Kroes's avatar Taddeüs Kroes
Browse files

Implemented an optimization in new validation implementation.

parent f6d30399
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from parser import Parser from parser import Parser, MAXIMUM_REWRITE_STEPS
from possibilities import apply_suggestion from possibilities import apply_suggestion
from strategy import find_possibilities from strategy import find_possibilities
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
...@@ -26,7 +26,7 @@ VALIDATE_ERROR = 3 ...@@ -26,7 +26,7 @@ VALIDATE_ERROR = 3
def validate(a, b): def validate(a, b):
""" """
Validate that a =>* b. Validate that a => b.
""" """
parser = ParserWrapper(Parser) parser = ParserWrapper(Parser)
...@@ -34,13 +34,29 @@ def validate(a, b): ...@@ -34,13 +34,29 @@ def validate(a, b):
a = parser.run([a]) a = parser.run([a])
b = parser.run([b]) b = parser.run([b])
if a.equals(b):
return VALIDATION_NOPROGRESS
# Evaluate a and b, counting the number of steps # Evaluate a and b, counting the number of steps
# TODO: Optimization: if b is encountered while evaluating a, return # Optimization: if b is encountered while evaluating a, return
# VALIDATION_SUCCESS
parser.set_root_node(a) parser.set_root_node(a)
A, a_steps = parser.rewrite_and_count_all() A = a
a_steps = 0
for i in xrange(MAXIMUM_REWRITE_STEPS):
obj = parser.rewrite()
if not obj:
break
# If b is some reduction of a, it will be detected here
if obj.equals(b):
return VALIDATE_SUCCESS
A = obj
a_steps += 1
if not a: if not A:
return VALIDATE_ERROR return VALIDATE_ERROR
parser.set_root_node(b) parser.set_root_node(b)
......
...@@ -17,7 +17,8 @@ import doctest ...@@ -17,7 +17,8 @@ import doctest
from src.node import ExpressionNode from src.node import ExpressionNode
from src.parser import Parser from src.parser import Parser
from src.validation import validate from src.validation import validate, VALIDATE_SUCCESS, \
VALIDATE_FAILURE, VALIDATE_NOPROGRESS
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
...@@ -92,6 +93,14 @@ class RulesTestCase(unittest.TestCase): ...@@ -92,6 +93,14 @@ class RulesTestCase(unittest.TestCase):
raise raise
def assertValidate(self, exp, result): def assertValidateSuccess(self, a, b):
self.assertTrue(validate(exp, result), self.assertEqual(validate(a, b), VALIDATE_SUCCESS,
'Validation failed: %s !=> %s') 'Validation failed: %s !=> %s' % (a, b))
def assertValidateFailure(self, a, b):
self.assertEqual(validate(a, b), VALIDATE_FAILURE,
'Validation dit not fail: %s => %s' % (a, b))
def assertValidateNoprogress(self, a, b):
self.assertEqual(validate(a, b), VALIDATE_NOPROGRESS, 'Validation '
'did detect progress or failed for %s => %s' % (a, b))
...@@ -12,43 +12,43 @@ ...@@ -12,43 +12,43 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from unittest import TestCase from rulestestcase import RulesTestCase
from src.validation import validate, VALIDATE_SUCCESS as OK, \ from src.validation import validate, VALIDATE_SUCCESS as OK, \
VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP
class TestValidation(TestCase): class TestValidation(RulesTestCase):
def test_simple_success(self): def test_simple_success(self):
self.assertEqual(validate('3a + a', '4a'), OK) self.assertValidateSuccess('3a + a', '4a')
def test_simple_failure(self): def test_simple_failure(self):
self.assertEqual(validate('3a + a', '4a + 1'), FAIL) self.assertValidateFailure('3a + a', '4a + 1')
def test_intermediate_success(self): def test_intermediate_success(self):
self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'), OK) self.assertValidateSuccess('3a + a + b + 2b', '4a + 3b')
self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'), OK) self.assertValidateSuccess('a / b / (c / d)', '(ad) / (bc)')
def test_intermediate_failure(self): def test_intermediate_failure(self):
self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'), FAIL) self.assertValidateFailure('3a + a + b + 2b', '4a + 4b')
def test_success(self): def test_success(self):
self.assertEqual(validate('x^2 + x - 2x^2 + 3x + 1', self.assertValidateSuccess('x^2 + x - 2x^2 + 3x + 1',
'x^2 + 4x - 2x^2 + 1'), OK) 'x^2 + 4x - 2x^2 + 1')
def test_indefinite_integral(self): def test_indefinite_integral(self):
self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'), OK) self.assertValidateSuccess('int_2^4 x^2', '4^3/3 - 2^3/3')
def test_advanced_failure(self): def test_advanced_failure(self):
self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'), FAIL) self.assertValidateFailure('(x-1)^3+(x-1)^3', '4a+4b')
def test_sphere_volume(self): def test_sphere_volume(self):
self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
'4 / 3 * pi * r ^ 3'), OK) '4 / 3 * pi * r ^ 3')
def test_sphere_volume_alternative_notation(self): def test_sphere_volume_alternative_notation(self):
self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertValidateSuccess('int_(-r)^(r) pi * (r^2 - x^2) dx',
'4 * pi * r ^ 3 / 3'), OK) '4 * pi * r ^ 3 / 3')
def test_noprogress_simple(self): def test_noprogress_simple(self):
self.assertEqual(validate('2 + 2', '3 + 1'), NP) self.assertValidateNoprogress('2 + 2', '3 + 1')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment