Added support for OP_NEG in multiply_numerics and add_nominators.

parent ee315bc7
......@@ -90,6 +90,15 @@ class ExpressionBase(object):
# Self is a leaf, thus has less value than an expression node.
return True
if self.is_op(OP_NEG) and self[0].is_leaf():
if other.is_leaf():
# Both are leafs, string compare the value.
return ('-' + str(self.value)) < str(other.value)
if other.is_op(OP_NEG) and other[0].is_leaf():
return ('-' + str(self.value)) < ('-' + str(other.value))
# Self is a leaf, thus has less value than an expression node.
return True
if other.is_leaf():
# Self is an expression node, and the other is a leaf. Thus, other
# is greater than self.
......
from itertools import product, combinations
from .utils import nary_node
from ..node import OP_ADD, OP_MUL
from ..node import OP_ADD, OP_MUL, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -19,7 +19,7 @@ def match_expand(node):
additions = []
for n in node.get_scope():
if n.is_leaf():
if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
leaves.append(n)
elif n.op == OP_ADD:
additions.append(n)
......
......@@ -87,8 +87,15 @@ def match_add_constant_fractions(node):
fractions = filter(is_division, node.get_scope())
for a, b in combinations(fractions, 2):
na, da = a if a.is_op(OP_DIV) else a[0]
nb, db = b if b.is_op(OP_DIV) else b[0]
if a.is_op(OP_NEG):
na, da = a[0]
else:
na, da = a
if b.is_op(OP_NEG):
nb, db = b[0]
else:
nb, db = b
if da == db:
# Equal denominators, add nominators to create a single fraction
......@@ -134,17 +141,22 @@ MESSAGES[equalize_denominators] = _('Equalize the denominators of division'
def add_nominators(root, args):
"""
a / b + c / b -> (a + c) / b
a / b + (-c / b) -> (a + (-c)) / b
a / -b + c / -b -> (a + c) / -b
a / -b - c / -b -> (a - c) / -b
"""
# TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"?
ab, cb = args
if ab.is_op(OP_NEG):
a, b = ab[0]
else:
a, b = ab
if cb[0].is_op(OP_NEG):
c = cb[0][0]
substitution = (a + (-c)) / b
if cb.is_op(OP_NEG):
c = -cb[0][0]
else:
c = cb[0]
substitution = (a + c) / b
scope = root.get_scope()
......@@ -158,7 +170,9 @@ def add_nominators(root, args):
return nary_node('+', scope)
MESSAGES[add_nominators] = _('Add nominators of the division of {1} by {2}.')
# TODO: convert this to a lambda. Example: 22 / 77 - 28 / 77. the "-" is above
# the "28/77" division.
MESSAGES[add_nominators] = _('Add nominators {1[0]} and {2[0]} of the division.')
def match_expand_and_add_fractions(node):
......
......@@ -18,8 +18,15 @@ def add_numerics(root, args):
"""
n0, n1, c0, c1 = args
c0 = (-c0[0].value) if c0.is_op(OP_NEG) else c0.value
c1 = (-c1[0].value) if c1.is_op(OP_NEG) else c1.value
if c0.is_op(OP_NEG):
c0 = (-c0[0].value)
else:
c0 = c0.value
if c1.is_op(OP_NEG):
c1 = (-c1[0].value)
else:
c1 = c1.value
scope = root.get_scope()
......@@ -110,11 +117,16 @@ def match_multiply_numerics(node):
assert node.is_op(OP_MUL)
p = []
scope = node.get_scope()
numerics = filter(lambda n: n.is_numeric(), scope)
numerics = []
for n in node.get_scope():
if n.is_numeric():
numerics.append((n, n.value))
elif n.is_op(OP_NEG) and n[0].is_numeric():
numerics.append((n, n[0].value))
for args in combinations(numerics, 2):
p.append(P(node, multiply_numerics, args))
for (n0, v0), (n1, v1) in combinations(numerics, 2):
p.append(P(node, multiply_numerics, (n0, n1, v0, v1)))
return p
......@@ -126,13 +138,19 @@ def multiply_numerics(root, args):
Example:
2 * 3 -> 6
"""
n0, n1 = args
n0, n1, v0, v1 = args
scope = []
value = v0 * v1
if value > 0:
substitution = Leaf(value)
else:
substitution = -Leaf(-value)
for n in root.get_scope():
if hash(n) == hash(n0):
# Replace the left node with the new expression
scope.append(Leaf(n0.value * n1.value))
scope.append(substitution)
#scope.append(n)
elif hash(n) != hash(n1):
# Remove the right node
......
......@@ -79,7 +79,10 @@ def combine_polynomes(root, args):
n0, n1, c0, c1, r, e = args
# a ^ 1 -> a
power = r if e == 1 else r ** e
if e == 1:
power = r
else:
power = r ** e
# replacement: (c0 + c1) * a ^ b
# a, b and c are from 'left', d is from 'right'.
......
......@@ -8,6 +8,10 @@ def tree(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp])
def rewrite(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp, '@'])
class RulesTestCase(unittest.TestCase):
def assertEqualPos(self, possibilities, expected):
......@@ -35,3 +39,12 @@ class RulesTestCase(unittest.TestCase):
for ca, cb in zip(a, b):
self.assertEqualNodes(ca, cb)
def assertRewrite(self, rewrite_chain):
try:
for i, exp in enumerate(rewrite_chain[:-1]):
self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
except AssertionError: # pragma: nocover
print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
print 'rewrite chain:', rewrite_chain
raise
from unittest import TestCase
from src.parser import Parser
from tests.parser import ParserWrapper
def rewrite(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp, '@'])
from tests.rulestestcase import RulesTestCase as TestCase, rewrite
class TestLeidenOefenopgave(TestCase):
def test_1(self):
for chain in [['-5(x2 - 3x + 6)', '-5(x ^ 2 - 3x) - 5 * 6',
# FIXME: '-5 * x ^ 2 - 5 * -3x - 5 * 6',
# FIXME: '-5 * x ^ 2 - 5 * -3x - 30',
], #'-30 + 15 * x - 5 * x ^ 2'],
]:
self.assertRewrite(chain)
return
for exp, solution in [
('-5(x2 -3x + 6)', '-30 + 15 * x - 5 * x ^ 2'),
('-5(x2 - 3x + 6)', '-30 + 15 * x - 5 * x ^ 2'),
('(x+1)^2', 'x ^ 2 + 2 * x + 1'),
('(x-1)^2', 'x ^ 2 - 2 * x + 1'),
('(2x+x)*x', '3 * x ^ 2'),
......@@ -32,9 +33,9 @@ class TestLeidenOefenopgave(TestCase):
('2/15 + 1/4', '8 / 60 + 15 / 60'),
('8/60 + 15/60', '(8 + 15) / 60'),
('(8 + 15) / 60', '23 / 60'),
('2/7 - 4/11', '22 / 77 + -28 / 77'),
('22/77 + -28/77', '(22 + -28) / 77'),
('(22 + -28)/77', '-6 / 77'),
('2/7 - 4/11', '22 / 77 - 28 / 77'),
('22/77 - 28/77', '(22 - 28) / 77'),
('(22 - 28)/77', '-6 / 77'),
# FIXME: ('(7/3) * (3/5)', '7 / 5'),
# FIXME: ('(3/4) / (5/6)', '9 / 10'),
# FIXME: ('1/4 * 1/x', '1 / (4x)'),
......
from unittest import TestCase
from src.parser import Parser
from tests.parser import ParserWrapper
def rewrite(exp, **kwargs):
return ParserWrapper(Parser, **kwargs).run([exp, '@'])
from tests.rulestestcase import RulesTestCase as TestCase
class TestRewrite(TestCase):
def assertRewrite(self, rewrite_chain):
try:
for i, exp in enumerate(rewrite_chain[:-1]):
self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
except AssertionError: # pragma: nocover
print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
print 'rewrite chain:', rewrite_chain
raise
def test_addition_rewrite(self):
self.assertRewrite(['2 + 3 + 4', '5 + 4', '9'])
......@@ -26,5 +10,5 @@ class TestRewrite(TestCase):
self.assertRewrite(['2 + 3a + 4', '6 + 3a'])
def test_division_rewrite(self):
self.assertRewrite(['2/7 - 4/11', '22 / 77 + -28 / 77',
'(22 + -28) / 77', '-6 / 77'])
self.assertRewrite(['2/7 - 4/11', '22 / 77 - 28 / 77',
'(22 - 28) / 77', '-6 / 77'])
......@@ -107,7 +107,14 @@ class TestRulesFractions(RulesTestCase):
n0, n1 = root = a / b + c / b
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + c) / b)
#2 / 4 + 3 / -4 -> 2 / 4 + -3 / 4
#2 / 4 - 3 / 4 -> -1 / 4 # Equal denominators, so nominators can
n0, n1 = root = a / b + (-c / b)
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + (-c)) / b)
n0, n1 = root = a / b + -c / b
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / b)
n0, n1 = root = a / b + -(c / b)
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / b)
n0, n1 = root = a / -b + c / -b
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + c) / -b)
n0, n1 = root = a / -b + -c / -b
self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / -b)
......@@ -70,27 +70,33 @@ class TestRulesNumerics(RulesTestCase):
root = i3 * i2
self.assertEqual(match_multiply_numerics(root),
[P(root, multiply_numerics, (i3, i2))])
[P(root, multiply_numerics, (i3, i2, 3, 2))])
root = f3 * i2
self.assertEqual(match_multiply_numerics(root),
[P(root, multiply_numerics, (f3, i2))])
[P(root, multiply_numerics, (f3, i2, 3.0, 2))])
root = i3 * f2
self.assertEqual(match_multiply_numerics(root),
[P(root, multiply_numerics, (i3, f2))])
[P(root, multiply_numerics, (i3, f2, 3, 2.0))])
root = f3 * f2
self.assertEqual(match_multiply_numerics(root),
[P(root, multiply_numerics, (f3, f2))])
[P(root, multiply_numerics, (f3, f2, 3.0, 2.0))])
def test_multiply_numerics(self):
a, b, i2, i3, i6, f2, f3, f6 = tree('a,b,2,3,6,2.0,3.0,6.0')
self.assertEqual(multiply_numerics(i3 * i2, (i3, i2)), 6)
self.assertEqual(multiply_numerics(f3 * i2, (f3, i2)), 6.0)
self.assertEqual(multiply_numerics(i3 * f2, (i3, f2)), 6.0)
self.assertEqual(multiply_numerics(f3 * f2, (f3, f2)), 6.0)
self.assertEqual(multiply_numerics(i3 * i2, (i3, i2, 3, 2)), 6)
self.assertEqual(multiply_numerics(f3 * i2, (f3, i2, 3.0, 2)), 6.0)
self.assertEqual(multiply_numerics(i3 * f2, (i3, f2, 3, 2.0)), 6.0)
self.assertEqual(multiply_numerics(f3 * f2, (f3, f2, 3.0, 2.0)), 6.0)
self.assertEqualNodes(multiply_numerics(a * i3 * i2 * b, (i3, i2)),
self.assertEqualNodes(multiply_numerics(a * i3 * i2 * b, (i3, i2, 3, 2)),
a * 6 * b)
def test_multiply_numerics_negation(self):
#a, b = root = tree('1 - 5 * -3x - 5 * 6')
l1, l2 = tree('-1 * 2')
self.assertEqual(multiply_numerics(l1 * l2, (l1, l2, -1, 2)), -l2)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment