Validate using breadth-first search (this is an experimental feature).

parent fb74c3df
......@@ -29,7 +29,7 @@ import re
# Rewriting an expression is stopped after this number of steps is passed.
MAXIMUM_REWRITE_STEPS = 30
MAXIMUM_REWRITE_STEPS = 20
# Check for n-ary operator in child nodes
......
from parser import Parser
from parser import Parser, MAXIMUM_REWRITE_STEPS
from possibilities import apply_suggestion
from strategy import find_possibilities
from tests.parser import ParserWrapper
#def traverse_breadth_first(node, result, depth=0):
# if depth > MAXIMUM_REWRITE_STEPS:
# #raise RuntimeError('MAXIMUM_REWRITE_STEPS is reached.')
# return
#
# if depth > 15:
# print '%3d %-30s -> %-30s' % (depth, str(node), str(result))
#
# children = []
# possibilities = find_possibilities(node)
#
# for p, possibility in enumerate(possibilities):
# # Clone the root node because it will be used in multiple
# # substitutions
# child = apply_suggestion(node.clone(), possibility)
#
# if child.equals(result):
# return child
#
# children.append(child)
#
# # If the final expression is not found in the direct children,
# # start searching in the children of the children.
# for c, child in enumerate(children):
# child_step = traverse_breadth_first(child, result, depth + 1)
#
# if child_step:
# return child_step
from collections import deque, defaultdict
def traverse_breadth_first(root, result, max_iterations=1e4):
queue = deque([root, 0])
i = 0
print 'root:', root, 'result:', result
counter = defaultdict(int)
while queue:
if i > max_iterations:
print 'unique:', len(counter)
print '\n'.join(sorted(map(str, counter.iteritems())))
raise RuntimeError('max_iterations is reached.')
i += 1
node = queue.popleft()
counter[str(node)] += 1
if node == 0:
print 'next depth: i = %d' % i
if queue:
queue.append(0)
continue
if node.equals(result):
return node
queue.extend([apply_suggestion(node.clone(), p) for p \
in find_possibilities(node)])
REWRITE_INVALID = 0 # Invalid original expression (e.g. syntax error)
REWRITE_FAILURE = 1 # One step failed to reduce to the original's reduction.
REWRITE_SUCCESS = 2 # Steps are valid, but one or more are not in the tree.
REWRITE_CHECKED = 3 # Steps are valid, and all are in the possibility tree.
def validate_new(original, *steps):
"""
Validate that original expression can (in)directly be rewritten as steps_0,
steps_1, ..., steps_n, and in that order.
"""
## TODO: make sure cycles are avoided / eliminated using cycle detection.
parser = ParserWrapper(Parser)
original = parser.run([original])
original_reduced = parser.rewrite_all()
if not original_reduced:
return REWRITE_INVALID
node = original
traversal = True
for s, step in enumerate(steps):
# Compare the simplified expressions first, in order to avoid the
# computational intensive traversal of the possibilities tree.
step = parser.run([step])
step_reduced = parser.rewrite_all()
if not original_reduced.equals(step_reduced):
return REWRITE_FAILURE, s
if traversal:
if not traverse_breadth_first(node, step):
traversal = False
node = step
if traversal:
return REWRITE_CHECKED, s
return REWRITE_SUCCESS, s
def validate(exp, result):
"""
Validate that exp =>* result.
......
from unittest import TestCase
from src.validation import validate
from src.validation import validate, REWRITE_SUCCESS, REWRITE_FAILURE, \
REWRITE_INVALID, REWRITE_CHECKED
class TestValidation(TestCase):
def test_simple_success(self):
self.assertTrue(validate('3a + a', '4a'))
# TODO: test REWRITE_INVALID (because a BisonSyntaxError is thrown now).
#def test_INVALID(self):
# self.assertEqual(validate('3a +', '3a'), REWRITE_INVALID)
def test_simple_CHECKED(self):
self.assertEqual(validate('3a + a', '4a'), (REWRITE_CHECKED, 0))
def test_simple_failure(self):
self.assertFalse(validate('3a + a', '4a + 1'))
self.assertEqual(validate('3a + a', '4a + 1'), (REWRITE_FAILURE, 0))
def test_intermediate_success(self):
self.assertTrue(validate('3a + a + b + 2b', '4a + 3b'))
self.assertTrue(validate('a / b / (c / d)', '(ad) / (bc)'))
self.assertEqual(validate('3a + a + b + 2b', '4a + 3b'),
(REWRITE_CHECKED, 0))
self.assertEqual(validate('a / b / (c / d)', '(ad) / (bc)'),
(REWRITE_CHECKED, 0))
def test_intermediate_failure(self):
self.assertFalse(validate('3a + a + b + 2b', '4a + 4b'))
self.assertEqual(validate('3a + a + b + 2b', '4a + 4b'),
(REWRITE_FAILURE, 0))
#def test_indefinite_integral(self):
# self.assertTrue(validate('int_2^4 x^2', '4^3/3 - 2^3/3'))
# TODO: this test fails due 'maximum recursion depth exceeded'.
def test_indefinite_integral(self):
self.assertEqual(validate('int_2^4 x^2', '4^3/3 - 2^3/3'),
(REWRITE_SUCCESS, 0))
#def test_advanced_failure(self):
# self.assertFalse(validate('(x-1)^3+(x-1)^3', '4a+4b'))
def test_advanced_failure(self):
self.assertEqual(validate('(x-1)^3+(x-1)^3', '4a+4b'),
(REWRITE_FAILURE, 0))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment