Commit 0eadc533 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented a new, lightweight way of validation as described in the TODO list.

parent 0efcd75c
...@@ -382,6 +382,11 @@ class Parser(BisonParser): ...@@ -382,6 +382,11 @@ class Parser(BisonParser):
return self.root_node return self.root_node
def rewrite_and_count_all(self, check_implicit=True, verbose=False):
steps = self.rewrite_all(include_steps=True,
check_implicit=check_implicit, verbose=verbose)
return self.root_node, len(steps)
#def hook_run(self, filename, retval): #def hook_run(self, filename, retval):
# return retval # return retval
......
...@@ -18,45 +18,43 @@ from strategy import find_possibilities ...@@ -18,45 +18,43 @@ from strategy import find_possibilities
from tests.parser import ParserWrapper from tests.parser import ParserWrapper
def validate(exp, result): VALIDATE_FAILURE = 0
VALIDATE_NOPROGRESS = 1
VALIDATE_SUCCESS = 2
VALIDATE_ERROR = 3
def validate(a, b):
""" """
Validate that exp =>* result. Validate that a =>* b.
""" """
parser = ParserWrapper(Parser) parser = ParserWrapper(Parser)
exp = parser.run([exp]) # Parse both expressions
result = parser.run([result]) a = parser.run([a])
b = parser.run([b])
# Compare the simplified expressions first, in order to avoid the # Evaluate a and b, counting the number of steps
# computational intensive traversal of the possibilities tree. parser.set_root_node(a)
parser.set_root_node(exp) A, a_steps = parser.rewrite_and_count_all()
a = parser.rewrite_all()
if not a: if not a:
return False return VALIDATE_ERROR
parser.set_root_node(result)
b = parser.rewrite_all()
if not a.equals(b):
return False
# TODO: make sure cycles are avoided / eliminated using cycle detection. parser.set_root_node(b)
def traverse_preorder(node, result): B, b_steps = parser.rewrite_and_count_all()
#print 'node:', node, 'result:', result
if node.equals(result):
return True
for p in find_possibilities(node): if not B:
# Clone the root node because it will be used in multiple return VALIDATE_ERROR
# substitutions
temp = node.clone()
child = apply_suggestion(node, p)
node = temp
if traverse_preorder(child, result): # Evaluations must be equal
return True if not A.equals(B):
return VALIDATE_FAILURE
return False # If evaluation of b took more staps than evaluation of a, the step from a
# to b was probably useless or even bad
if b_steps >= a_steps:
return VALIDATE_NOPROGRESS
return traverse_preorder(exp, result) # Evaluations match and b is evaluated quicker than a => success
return VALIDATE_SUCCESS
...@@ -13,38 +13,42 @@ ...@@ -13,38 +13,42 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from unittest import TestCase from unittest import TestCase
from src.validation import validate from src.validation import validate, VALIDATE_SUCCESS as OK, \
VALIDATE_FAILURE as FAIL, VALIDATE_NOPROGRESS as NP
class TestValidation(TestCase): class TestValidation(TestCase):
def test_simple_success(self): def test_simple_success(self):
self.assertTrue(validate('3a + a', '4a')) self.assertEqual(validate('3a + a', '4a'), OK)
def test_simple_failure(self): def test_simple_failure(self):
self.assertFalse(validate('3a + a', '4a + 1')) self.assertEqual(validate('3a + a', '4a + 1'), FAIL)
def test_intermediate_success(self): def test_intermediate_success(self):
self.assertTrue(validate('3a + a + b + 2b', '4a + 3b')) self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'), OK)
self.assertTrue(validate('a / b / (c / d)', '(ad) / (bc)')) self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'), OK)
def test_intermediate_failure(self): def test_intermediate_failure(self):
self.assertFalse(validate('3a + a + b + 2b', '4a + 4b')) self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'), FAIL)
#def test_success(self): def test_success(self):
# self.assertTrue(validate('x^2 + x - 2x^2 + 3x + 1', self.assertEqual(validate('x^2 + x - 2x^2 + 3x + 1',
# 'x^2 + 4x - 2x^2 + 1')) 'x^2 + 4x - 2x^2 + 1'), OK)
#def test_indefinite_integral(self): def test_indefinite_integral(self):
# self.assertTrue(validate('int_2^4 x^2', '4^3/3 - 2^3/3')) self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'), OK)
#def test_advanced_failure(self): def test_advanced_failure(self):
# self.assertFalse(validate('(x-1)^3+(x-1)^3', '4a+4b')) self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'), FAIL)
def test_sphere_volume(self): def test_sphere_volume(self):
self.assertTrue(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
'4 / 3 * pi * r ^ 3')) '4 / 3 * pi * r ^ 3'), OK)
#def test_sphere_volume_alternative(self): def test_sphere_volume_alternative_notation(self):
# self.assertTrue(validate('int_(-r)^(r) pi * (r^2 - x^2) dx', self.assertEqual(validate('int_(-r)^(r) pi * (r^2 - x^2) dx',
# '4 * pi * r ^ 3 / 3')) '4 * pi * r ^ 3 / 3'), OK)
def test_noprogress_simple(self):
self.assertEqual(validate('2 + 2', '3 + 1'), NP)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment