浏览代码

Added support for OP_NEG in multiply_numerics and add_nominators.

Sander Mathijs van Veen 14 年之前
父节点
当前提交
388752b486

+ 9 - 0
src/node.py

@@ -90,6 +90,15 @@ class ExpressionBase(object):
             # Self is a leaf, thus has less value than an expression node.
             return True
 
+        if self.is_op(OP_NEG) and self[0].is_leaf():
+            if other.is_leaf():
+                # Both are leafs, string compare the value.
+                return ('-' + str(self.value)) < str(other.value)
+            if other.is_op(OP_NEG) and other[0].is_leaf():
+                return ('-' + str(self.value)) < ('-' + str(other.value))
+            # Self is a leaf, thus has less value than an expression node.
+            return True
+
         if other.is_leaf():
             # Self is an expression node, and the other is a leaf. Thus, other
             # is greater than self.

+ 2 - 2
src/rules/factors.py

@@ -1,7 +1,7 @@
 from itertools import product, combinations
 
 from .utils import nary_node
-from ..node import OP_ADD, OP_MUL
+from ..node import OP_ADD, OP_MUL, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -19,7 +19,7 @@ def match_expand(node):
     additions = []
 
     for n in node.get_scope():
-        if n.is_leaf():
+        if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
             leaves.append(n)
         elif n.op == OP_ADD:
             additions.append(n)

+ 24 - 10
src/rules/fractions.py

@@ -87,8 +87,15 @@ def match_add_constant_fractions(node):
     fractions = filter(is_division, node.get_scope())
 
     for a, b in combinations(fractions, 2):
-        na, da = a if a.is_op(OP_DIV) else a[0]
-        nb, db = b if b.is_op(OP_DIV) else b[0]
+        if a.is_op(OP_NEG):
+            na, da = a[0]
+        else:
+            na, da = a
+
+        if b.is_op(OP_NEG):
+            nb, db = b[0]
+        else:
+            nb, db = b
 
         if da == db:
             # Equal denominators, add nominators to create a single fraction
@@ -133,19 +140,24 @@ MESSAGES[equalize_denominators] = _('Equalize the denominators of division'
 
 def add_nominators(root, args):
     """
-    a / b + c / b     ->  (a + c) / b
-    a / b + (-c / b)  ->  (a + (-c)) / b
+    a / b + c / b    ->  (a + c) / b
+    a / -b + c / -b  ->  (a + c) / -b
+    a / -b - c / -b  ->  (a - c) / -b
     """
     # TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"?
     ab, cb = args
-    a, b = ab
 
-    if cb[0].is_op(OP_NEG):
-        c = cb[0][0]
-        substitution = (a + (-c)) / b
+    if ab.is_op(OP_NEG):
+        a, b = ab[0]
+    else:
+        a, b = ab
+
+    if cb.is_op(OP_NEG):
+        c = -cb[0][0]
     else:
         c = cb[0]
-        substitution = (a + c) / b
+
+    substitution = (a + c) / b
 
     scope = root.get_scope()
 
@@ -158,7 +170,9 @@ def add_nominators(root, args):
     return nary_node('+', scope)
 
 
-MESSAGES[add_nominators] = _('Add nominators of the division of {1} by {2}.')
+# TODO: convert this to a lambda. Example: 22 / 77 - 28 / 77. the "-" is above
+# the "28/77" division.
+MESSAGES[add_nominators] = _('Add nominators {1[0]} and {2[0]} of the division.')
 
 
 def match_expand_and_add_fractions(node):

+ 26 - 8
src/rules/numerics.py

@@ -18,8 +18,15 @@ def add_numerics(root, args):
     """
     n0, n1, c0, c1 = args
 
-    c0 = (-c0[0].value) if c0.is_op(OP_NEG) else c0.value
-    c1 = (-c1[0].value) if c1.is_op(OP_NEG) else c1.value
+    if c0.is_op(OP_NEG):
+        c0 = (-c0[0].value)
+    else:
+        c0 = c0.value
+
+    if c1.is_op(OP_NEG):
+        c1 = (-c1[0].value)
+    else:
+        c1 = c1.value
 
     scope = root.get_scope()
 
@@ -110,11 +117,16 @@ def match_multiply_numerics(node):
     assert node.is_op(OP_MUL)
 
     p = []
-    scope = node.get_scope()
-    numerics = filter(lambda n: n.is_numeric(), scope)
+    numerics = []
+
+    for n in node.get_scope():
+        if n.is_numeric():
+            numerics.append((n, n.value))
+        elif n.is_op(OP_NEG) and n[0].is_numeric():
+            numerics.append((n, n[0].value))
 
-    for args in combinations(numerics, 2):
-        p.append(P(node, multiply_numerics, args))
+    for (n0, v0), (n1, v1) in combinations(numerics, 2):
+        p.append(P(node, multiply_numerics, (n0, n1, v0, v1)))
 
     return p
 
@@ -126,13 +138,19 @@ def multiply_numerics(root, args):
     Example:
     2 * 3  ->  6
     """
-    n0, n1 = args
+    n0, n1, v0, v1 = args
     scope = []
+    value = v0 * v1
+
+    if value > 0:
+        substitution = Leaf(value)
+    else:
+        substitution = -Leaf(-value)
 
     for n in root.get_scope():
         if hash(n) == hash(n0):
             # Replace the left node with the new expression
-            scope.append(Leaf(n0.value * n1.value))
+            scope.append(substitution)
             #scope.append(n)
         elif hash(n) != hash(n1):
             # Remove the right node

+ 4 - 1
src/rules/poly.py

@@ -79,7 +79,10 @@ def combine_polynomes(root, args):
     n0, n1, c0, c1, r, e = args
 
     # a ^ 1 -> a
-    power = r if e == 1 else r ** e
+    if e == 1:
+        power = r
+    else:
+        power = r ** e
 
     # replacement: (c0 + c1) * a ^ b
     # a, b and c are from 'left', d is from 'right'.

+ 13 - 0
tests/rulestestcase.py

@@ -8,6 +8,10 @@ def tree(exp, **kwargs):
     return ParserWrapper(Parser, **kwargs).run([exp])
 
 
+def rewrite(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
+
+
 class RulesTestCase(unittest.TestCase):
 
     def assertEqualPos(self, possibilities, expected):
@@ -35,3 +39,12 @@ class RulesTestCase(unittest.TestCase):
 
         for ca, cb in zip(a, b):
             self.assertEqualNodes(ca, cb)
+
+    def assertRewrite(self, rewrite_chain):
+        try:
+            for i, exp in enumerate(rewrite_chain[:-1]):
+                self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
+        except AssertionError:  # pragma: nocover
+            print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
+            print 'rewrite chain:', rewrite_chain
+            raise

+ 13 - 12
tests/test_leiden_oefenopgave.py

@@ -1,18 +1,19 @@
-from unittest import TestCase
-
-from src.parser import Parser
-from tests.parser import ParserWrapper
-
-
-def rewrite(exp, **kwargs):
-    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
+from tests.rulestestcase import RulesTestCase as TestCase, rewrite
 
 
 class TestLeidenOefenopgave(TestCase):
     def test_1(self):
+        for chain in [['-5(x2 - 3x + 6)', '-5(x ^ 2 - 3x) - 5 * 6',
+                       # FIXME: '-5 * x ^ 2 - 5 * -3x - 5 * 6',
+                       # FIXME: '-5 * x ^ 2 - 5 * -3x - 30',
+                       ], #'-30 + 15 * x - 5 * x ^ 2'],
+                     ]:
+            self.assertRewrite(chain)
+
         return
+
         for exp, solution in [
-                ('-5(x2 -3x + 6)',       '-30 + 15 * x - 5 * x ^ 2'),
+                ('-5(x2 - 3x + 6)',       '-30 + 15 * x - 5 * x ^ 2'),
                 ('(x+1)^2',              'x ^ 2 + 2 * x + 1'),
                 ('(x-1)^2',              'x ^ 2 - 2 * x + 1'),
                 ('(2x+x)*x',             '3 * x ^ 2'),
@@ -32,9 +33,9 @@ class TestLeidenOefenopgave(TestCase):
                 ('2/15 + 1/4',      '8 / 60 + 15 / 60'),
                 ('8/60 + 15/60',    '(8 + 15) / 60'),
                 ('(8 + 15) / 60',   '23 / 60'),
-                ('2/7 - 4/11',      '22 / 77 + -28 / 77'),
-                ('22/77 + -28/77',  '(22 + -28) / 77'),
-                ('(22 + -28)/77',    '-6 / 77'),
+                ('2/7 - 4/11',      '22 / 77 - 28 / 77'),
+                ('22/77 - 28/77',  '(22 - 28) / 77'),
+                ('(22 - 28)/77',    '-6 / 77'),
                 # FIXME: ('(7/3) * (3/5)',   '7 / 5'),
                 # FIXME: ('(3/4) / (5/6)',   '9 / 10'),
                 # FIXME: ('1/4 * 1/x',       '1 / (4x)'),

+ 3 - 19
tests/test_rewrite.py

@@ -1,24 +1,8 @@
-from unittest import TestCase
-
-from src.parser import Parser
-from tests.parser import ParserWrapper
-
-
-def rewrite(exp, **kwargs):
-    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
+from tests.rulestestcase import RulesTestCase as TestCase
 
 
 class TestRewrite(TestCase):
 
-    def assertRewrite(self, rewrite_chain):
-        try:
-            for i, exp in enumerate(rewrite_chain[:-1]):
-                self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
-        except AssertionError:  # pragma: nocover
-            print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
-            print 'rewrite chain:', rewrite_chain
-            raise
-
     def test_addition_rewrite(self):
         self.assertRewrite(['2 + 3 + 4', '5 + 4', '9'])
 
@@ -26,5 +10,5 @@ class TestRewrite(TestCase):
         self.assertRewrite(['2 + 3a + 4', '6 + 3a'])
 
     def test_division_rewrite(self):
-        self.assertRewrite(['2/7 - 4/11', '22 / 77 + -28 / 77',
-                            '(22 + -28) / 77', '-6 / 77'])
+        self.assertRewrite(['2/7 - 4/11', '22 / 77 - 28 / 77',
+                            '(22 - 28) / 77', '-6 / 77'])

+ 11 - 4
tests/test_rules_fractions.py

@@ -107,7 +107,14 @@ class TestRulesFractions(RulesTestCase):
         n0, n1 = root = a / b + c / b
         self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + c) / b)
 
-        #2 / 4 + 3 / -4  ->  2 / 4 + -3 / 4
-        #2 / 4 - 3 / 4  ->  -1 / 4  # Equal denominators, so nominators can
-        n0, n1 = root = a / b + (-c / b)
-        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + (-c)) / b)
+        n0, n1 = root = a / b + -c / b
+        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / b)
+
+        n0, n1 = root = a / b + -(c / b)
+        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / b)
+
+        n0, n1 = root = a / -b + c / -b
+        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + c) / -b)
+
+        n0, n1 = root = a / -b + -c / -b
+        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + -c) / -b)

+ 15 - 9
tests/test_rules_numerics.py

@@ -70,27 +70,33 @@ class TestRulesNumerics(RulesTestCase):
 
         root = i3 * i2
         self.assertEqual(match_multiply_numerics(root),
-                [P(root, multiply_numerics, (i3, i2))])
+                [P(root, multiply_numerics, (i3, i2, 3, 2))])
 
         root = f3 * i2
         self.assertEqual(match_multiply_numerics(root),
-                [P(root, multiply_numerics, (f3, i2))])
+                [P(root, multiply_numerics, (f3, i2, 3.0, 2))])
 
         root = i3 * f2
         self.assertEqual(match_multiply_numerics(root),
-                [P(root, multiply_numerics, (i3, f2))])
+                [P(root, multiply_numerics, (i3, f2, 3, 2.0))])
 
         root = f3 * f2
         self.assertEqual(match_multiply_numerics(root),
-                [P(root, multiply_numerics, (f3, f2))])
+                [P(root, multiply_numerics, (f3, f2, 3.0, 2.0))])
 
     def test_multiply_numerics(self):
         a, b, i2, i3, i6, f2, f3, f6 = tree('a,b,2,3,6,2.0,3.0,6.0')
 
-        self.assertEqual(multiply_numerics(i3 * i2, (i3, i2)), 6)
-        self.assertEqual(multiply_numerics(f3 * i2, (f3, i2)), 6.0)
-        self.assertEqual(multiply_numerics(i3 * f2, (i3, f2)), 6.0)
-        self.assertEqual(multiply_numerics(f3 * f2, (f3, f2)), 6.0)
+        self.assertEqual(multiply_numerics(i3 * i2, (i3, i2, 3, 2)), 6)
+        self.assertEqual(multiply_numerics(f3 * i2, (f3, i2, 3.0, 2)), 6.0)
+        self.assertEqual(multiply_numerics(i3 * f2, (i3, f2, 3, 2.0)), 6.0)
+        self.assertEqual(multiply_numerics(f3 * f2, (f3, f2, 3.0, 2.0)), 6.0)
 
-        self.assertEqualNodes(multiply_numerics(a * i3 * i2 * b, (i3, i2)),
+        self.assertEqualNodes(multiply_numerics(a * i3 * i2 * b, (i3, i2, 3, 2)),
                               a * 6 * b)
+
+    def test_multiply_numerics_negation(self):
+        #a, b = root = tree('1 - 5 * -3x - 5 * 6')
+        l1, l2 = tree('-1 * 2')
+
+        self.assertEqual(multiply_numerics(l1 * l2, (l1, l2, -1, 2)), -l2)