Fixed merge conflict, added match_extend_exponent and improved add_exponents.

parent 5383845b
......@@ -19,3 +19,46 @@
- rewrite match_combine_polynomes to an even more generic form:
match_combine_factors.
- Fix division by zero caused by "0/0".
smvv@multivac ~/work/trs $ printf "a/0\n??" | ./main.py
Traceback (most recent call last):
File "./main.py", line 75, in <module>
main()
File "./main.py", line 64, in main
node = p.run(debug=args.debug)
File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 258, in run
self.report_last_error(filename, e)
File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 251, in run
self.engine.runEngine(debug)
File "bison_.pyx", line 592, in bison_.ParserEngine.runEngine (build/external/pybison/bison_.c:592)
File "/home/smvv/work/trs/src/parser.py", line 195, in hook_handler
possibilities = handler(retval)
File "/home/smvv/work/trs/src/rules/fractions.py", line 23, in match_constant_division
raise ZeroDivisionError('Division by zero: %s.' % node)
ZeroDivisionError: Division by zero: a / 0.
smvv@multivac ~/work/trs $ printf "0/0\n??" | ./main.py
Traceback (most recent call last):
File "./main.py", line 75, in <module>
main()
File "./main.py", line 64, in main
node = p.run(debug=args.debug)
File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 258, in run
self.report_last_error(filename, e)
File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 251, in run
self.engine.runEngine(debug)
File "bison_.pyx", line 592, in bison_.ParserEngine.runEngine (build/external/pybison/bison_.c:592)
File "/home/smvv/work/trs/src/parser.py", line 195, in hook_handler
possibilities = handler(retval)
File "/home/smvv/work/trs/src/rules/numerics.py", line 73, in match_divide_numerics
divide = not divmod(n.value, dv)[1]
ZeroDivisionError: integer division or modulo by zero
- Last possibilities reduce to a similar result.
smvv@multivac ~/work/trs $ printf "0/1\n??" | ./main.py
<Possibility root="0 / 1" handler=divide_numerics args=(0, 1)>
Division of 0 by 1 reduces to 0.
Division of 0 by 1 reduces to 0.
......@@ -4,7 +4,8 @@ from .groups import match_combine_groups
from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \
match_multiply_exponents, match_duplicate_exponent, \
match_remove_negative_exponent, match_exponent_to_root
match_remove_negative_exponent, match_exponent_to_root, \
match_extend_exponent
from .numerics import match_divide_numerics, match_multiply_numerics
from .fractions import match_constant_division, match_add_constant_fractions, \
match_expand_and_add_fractions
......@@ -18,5 +19,6 @@ RULES = {
OP_DIV: [match_subtract_exponents, match_divide_numerics, \
match_constant_division],
OP_POW: [match_multiply_exponents, match_duplicate_exponent, \
match_remove_negative_exponent, match_exponent_to_root],
match_remove_negative_exponent, match_exponent_to_root, \
match_extend_exponent],
}
......@@ -37,13 +37,16 @@ def match_combine_groups(node):
for i, sub_node in enumerate(scope):
if sub_node.is_numeric():
others = [scope[j] for j in range(i) + range(i + 1, l)]
g = others[0] if len(others) == 1 else Node('*', *others)
if len(others) == 1:
g = others[0]
else:
g = Node('*', *others)
groups.append((sub_node, g, n))
#print [map(str, group) for group in groups]
for g0, g1 in combinations(groups, 2):
if g0[1].equals(g1[1]):
#print type(g0[1]), str(g0[1]), 'equals', type(g1[1]), str(g1[1])
p.append(P(node, combine_groups, g0 + g1))
return p
......
from itertools import combinations
from ..node import ExpressionNode as N, ExpressionLeaf as L, \
OP_NEG, OP_MUL, OP_DIV, OP_POW
OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
from ..translate import _
......@@ -10,6 +10,9 @@ from ..translate import _
def match_add_exponents(node):
"""
a^p * a^q -> a^(p + q)
a * a^q -> a^(1 + q)
a^p * a -> a^(p + 1)
a * a -> a^(1 + 1)
"""
assert node.is_op(OP_MUL)
......@@ -17,26 +20,53 @@ def match_add_exponents(node):
powers = {}
for n in node.get_scope():
if n.is_op(OP_POW):
if n.is_identifier():
s = n
exponent = L(1)
elif n.is_op(OP_POW):
# Order powers by their roots, e.g. a^p and a^q are put in the same
# list because of the mutual 'a'
s = str(n[0])
s, exponent = n
else:
continue
s_str = str(s)
if s in powers:
powers[s].append(n)
if s_str in powers:
powers[s_str].append((n, exponent, s))
else:
powers[s] = [n]
powers[s_str] = [(n, exponent, s)]
for root, occurrences in powers.iteritems():
# If a root has multiple occurences, their exponents can be added to
# create a single power with that root
if len(occurrences) > 1:
for pair in combinations(occurrences, 2):
p.append(P(node, add_exponents, pair))
for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
p.append(P(node, add_exponents, (n0, n1, a0, e1, e2)))
return p
def add_exponents(root, args):
"""
a^p * a^q -> a^(p + q)
"""
n0, n1, a, p, q = args
scope = root.get_scope()
# Replace the left node with the new expression
scope[scope.index(n0)] = a ** (p + q)
# Remove the right node
scope.remove(n1)
return nary_node('*', scope)
MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
' will reduce to {1[0]}^({1[1]} + {2[1]}).')
def match_subtract_exponents(node):
"""
a^p / a^q -> a^(p - q)
......@@ -120,26 +150,32 @@ def match_exponent_to_root(node):
return []
def add_exponents(root, args):
def match_extend_exponent(node):
"""
a^p * a^q -> a^(p + q)
(a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
"""
n0, n1 = args
a, p = n0
q = n1[1]
scope = root.get_scope()
assert node.is_op(OP_POW)
# Replace the left node with the new expression
scope[scope.index(n0)] = a ** (p + q)
left, right = node
# Remove the right node
scope.remove(n1)
if right.is_numeric():
for n in node.get_scope():
if n.is_op(OP_ADD):
return [P(node, extend_exponent, (left, right))]
return nary_node('*', scope)
return []
MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
' will reduce to {1[0]}^({1[1]} + {2[1]}).')
def extend_exponent(root, args):
"""
(a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
"""
left, right = args
if right.value > 2:
return left * left ** L(right.value - 1)
return left * left
def subtract_exponents(root, args):
......
......@@ -43,7 +43,8 @@ class RulesTestCase(unittest.TestCase):
def assertRewrite(self, rewrite_chain):
try:
for i, exp in enumerate(rewrite_chain[:-1]):
self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
self.assertMultiLineEqual(str(rewrite(exp)),
str(rewrite_chain[i+1]))
except AssertionError: # pragma: nocover
print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
print 'rewrite chain:', rewrite_chain
......
......@@ -2,11 +2,14 @@ from tests.rulestestcase import RulesTestCase as TestCase, rewrite
class TestLeidenOefenopgave(TestCase):
def test_1(self):
def test_1_1(self):
for chain in [['-5(x2 - 3x + 6)', '-5(x ^ 2 - 3x) - 5 * 6',
'-5 * x ^ 2 - 5 * -3x - 5 * 6',
'-5 * x ^ 2 - -15x - 5 * 6',
# FIXME: '-5 * x ^ 2 - 5 * -3x - 30',
# FIXME: '-5 * x ^ 2 - -15x - 5 * 6',
# FIXME: '-5 * x ^ 2 + 15x - 5 * 6',
# FIXME: '-5 * x ^ 2 + 15x - 30',
], #'-30 + 15 * x - 5 * x ^ 2'],
]:
self.assertRewrite(chain)
......@@ -23,6 +26,56 @@ class TestLeidenOefenopgave(TestCase):
]:
self.assertEqual(str(rewrite(exp)), solution)
def test_1_2(self):
for chain in [['(x+1)^3', '(x + 1)(x + 1) ^ 2',
'(x + 1)(x + 1)(x + 1)',
'(xx + x * 1 + 1x + 1 * 1)(x + 1)',
'(x ^ (1 + 1) + x * 1 + 1x + 1 * 1)(x + 1)',
'(x ^ 2 + x * 1 + 1x + 1 * 1)(x + 1)',
'(x ^ 2 + (1 + 1)x + 1 * 1)(x + 1)',
'(x ^ 2 + 2x + 1 * 1)(x + 1)',
'(x ^ 2 + 2x + 1)(x + 1)',
'(x ^ 2 + 2x)x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x * x ^ 2 + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x ^ (1 + 2) + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x ^ 3 + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x ^ 3 + x ^ (1 + 1) * 2 + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x ^ 3 + x ^ 2 * 2 + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
'x ^ 3 + x ^ 2 * 2 + 1 * x ^ 2 + 1 * 2x + 1x + 1 * 1',
'x ^ 3 + (2 + 1) * x ^ 2 + 1 * 2x + 1x + 1 * 1',
'x ^ 3 + 3 * x ^ 2 + 1 * 2x + 1x + 1 * 1',
'x ^ 3 + 3 * x ^ 2 + 2x + 1x + 1 * 1',
'x ^ 3 + 3 * x ^ 2 + (2 + 1)x + 1 * 1',
'x ^ 3 + 3 * x ^ 2 + 3x + 1 * 1',
'x ^ 3 + 3 * x ^ 2 + 3x + 1',
]
]:
self.assertRewrite(chain)
def test_1_3(self):
# (x+1)^2 -> x^2 + 2x + 1
for chain in [['(x+1)^2', '(x + 1)(x + 1)',
'xx + x * 1 + 1x + 1 * 1',
'x ^ (1 + 1) + x * 1 + 1x + 1 * 1',
'x ^ 2 + x * 1 + 1x + 1 * 1',
'x ^ 2 + (1 + 1)x + 1 * 1',
'x ^ 2 + 2x + 1 * 1',
'x ^ 2 + 2x + 1'],
]:
self.assertRewrite(chain)
def test_1_4(self):
# (x-1)^2 -> x^2 - 2x + 1
for chain in [['(x-1)^2', '(x - 1)(x - 1)',
'xx + x * -1 - 1x - 1 * -1',
'x ^ (1 + 1) + x * -1 - 1x - 1 * -1',
'x ^ 2 + x * -1 - 1x - 1 * -1',
# FIXME: 'x ^ 2 + (-1 - 1)x - 1 * -1',
# FIXME: 'x ^ 2 - 2x - 1 * -1',
# FIXME: 'x ^ 2 - 2x + 1',
]]:
self.assertRewrite(chain)
def test_2(self):
pass
......
......@@ -17,7 +17,7 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities,
[P(root, add_exponents, (n0, n1))])
[P(root, add_exponents, (n0, n1, a, p, q))])
def test_match_add_exponents_ternary(self):
a, p, q, r = tree('a,p,q,r')
......@@ -25,9 +25,9 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities,
[P(root, add_exponents, (n0, n1)),
P(root, add_exponents, (n0, n2)),
P(root, add_exponents, (n1, n2))])
[P(root, add_exponents, (n0, n1, a, p, q)),
P(root, add_exponents, (n0, n2, a, p, r)),
P(root, add_exponents, (n1, n2, a, q, r))])
def test_match_add_exponents_multiple_identifiers(self):
a, b, p, q = tree('a,b,p,q')
......@@ -35,8 +35,8 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities,
[P(root, add_exponents, (a0, a1)),
P(root, add_exponents, (b0, b1))])
[P(root, add_exponents, (a0, a1, a, p, q)),
P(root, add_exponents, (b0, b1, b, p, q))])
def test_match_subtract_exponents_powers(self):
a, p, q = tree('a,p,q')
......@@ -103,7 +103,7 @@ class TestRulesPowers(RulesTestCase):
a, p, q = tree('a,p,q')
n0, n1 = root = a ** p * a ** q
self.assertEqualNodes(add_exponents(root, (n0, n1)), a ** (p + q))
self.assertEqualNodes(add_exponents(root, (n0, n1, a, p, q)), a ** (p + q))
def test_subtract_exponents(self):
a, p, q = tree('a,p,q')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment