Browse Source

Fixed merge conflict, added match_extend_exponent and improved add_exponents.

Sander Mathijs van Veen 14 years ago
parent
commit
9b18453bcb
7 changed files with 174 additions and 36 deletions
  1. 43 0
      TODO
  2. 4 2
      src/rules/__init__.py
  3. 6 3
      src/rules/groups.py
  4. 58 22
      src/rules/powers.py
  5. 2 1
      tests/rulestestcase.py
  6. 54 1
      tests/test_leiden_oefenopgave.py
  7. 7 7
      tests/test_rules_powers.py

+ 43 - 0
TODO

@@ -19,3 +19,46 @@
 
  - rewrite match_combine_polynomes to an even more generic form:
    match_combine_factors.
+
+ - Fix division by zero caused by "0/0".
+
+smvv@multivac ~/work/trs $ printf "a/0\n??" | ./main.py
+Traceback (most recent call last):
+  File "./main.py", line 75, in <module>
+    main()
+  File "./main.py", line 64, in main
+    node = p.run(debug=args.debug)
+  File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 258, in run
+    self.report_last_error(filename, e)
+  File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 251, in run
+    self.engine.runEngine(debug)
+  File "bison_.pyx", line 592, in bison_.ParserEngine.runEngine (build/external/pybison/bison_.c:592)
+  File "/home/smvv/work/trs/src/parser.py", line 195, in hook_handler
+    possibilities = handler(retval)
+  File "/home/smvv/work/trs/src/rules/fractions.py", line 23, in match_constant_division
+    raise ZeroDivisionError('Division by zero: %s.' % node)
+ZeroDivisionError: Division by zero: a / 0.
+
+smvv@multivac ~/work/trs $ printf "0/0\n??" | ./main.py
+Traceback (most recent call last):
+  File "./main.py", line 75, in <module>
+    main()
+  File "./main.py", line 64, in main
+    node = p.run(debug=args.debug)
+  File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 258, in run
+    self.report_last_error(filename, e)
+  File "/home/smvv/work/trs/external/pybison/src/python/bison.py", line 251, in run
+    self.engine.runEngine(debug)
+  File "bison_.pyx", line 592, in bison_.ParserEngine.runEngine (build/external/pybison/bison_.c:592)
+  File "/home/smvv/work/trs/src/parser.py", line 195, in hook_handler
+    possibilities = handler(retval)
+  File "/home/smvv/work/trs/src/rules/numerics.py", line 73, in match_divide_numerics
+    divide = not divmod(n.value, dv)[1]
+ZeroDivisionError: integer division or modulo by zero
+
+ - Last possibilities reduce to a similar result.
+
+smvv@multivac ~/work/trs $ printf "0/1\n??" | ./main.py
+<Possibility root="0 / 1" handler=divide_numerics args=(0, 1)>
+Division of 0 by 1 reduces to 0.
+Division of 0 by 1 reduces to 0.

+ 4 - 2
src/rules/__init__.py

@@ -4,7 +4,8 @@ from .groups import match_combine_groups
 from .factors import match_expand
 from .powers import match_add_exponents, match_subtract_exponents, \
         match_multiply_exponents, match_duplicate_exponent, \
-        match_remove_negative_exponent, match_exponent_to_root
+        match_remove_negative_exponent, match_exponent_to_root, \
+        match_extend_exponent
 from .numerics import match_divide_numerics, match_multiply_numerics
 from .fractions import match_constant_division, match_add_constant_fractions, \
         match_expand_and_add_fractions
@@ -18,5 +19,6 @@ RULES = {
         OP_DIV: [match_subtract_exponents, match_divide_numerics, \
                  match_constant_division],
         OP_POW: [match_multiply_exponents, match_duplicate_exponent, \
-                 match_remove_negative_exponent, match_exponent_to_root],
+                 match_remove_negative_exponent, match_exponent_to_root, \
+                 match_extend_exponent],
         }

+ 6 - 3
src/rules/groups.py

@@ -37,13 +37,16 @@ def match_combine_groups(node):
             for i, sub_node in enumerate(scope):
                 if sub_node.is_numeric():
                     others = [scope[j] for j in range(i) + range(i + 1, l)]
-                    g = others[0] if len(others) == 1 else Node('*', *others)
+
+                    if len(others) == 1:
+                        g = others[0]
+                    else:
+                        g = Node('*', *others)
+
                     groups.append((sub_node, g, n))
 
-    #print [map(str, group) for group in groups]
     for g0, g1 in combinations(groups, 2):
         if g0[1].equals(g1[1]):
-            #print type(g0[1]), str(g0[1]), 'equals', type(g1[1]), str(g1[1])
             p.append(P(node, combine_groups, g0 + g1))
 
     return p

+ 58 - 22
src/rules/powers.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as N, ExpressionLeaf as L, \
-                   OP_NEG, OP_MUL, OP_DIV, OP_POW
+                   OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
 from ..possibilities import Possibility as P, MESSAGES
 from .utils import nary_node
 from ..translate import _
@@ -10,6 +10,9 @@ from ..translate import _
 def match_add_exponents(node):
     """
     a^p * a^q  ->  a^(p + q)
+    a * a^q    ->  a^(1 + q)
+    a^p * a    ->  a^(p + 1)
+    a * a      ->  a^(1 + 1)
     """
     assert node.is_op(OP_MUL)
 
@@ -17,26 +20,53 @@ def match_add_exponents(node):
     powers = {}
 
     for n in node.get_scope():
-        if n.is_op(OP_POW):
+        if n.is_identifier():
+            s = n
+            exponent = L(1)
+        elif n.is_op(OP_POW):
             # Order powers by their roots, e.g. a^p and a^q are put in the same
             # list because of the mutual 'a'
-            s = str(n[0])
+            s, exponent = n
+        else:
+            continue
 
-            if s in powers:
-                powers[s].append(n)
-            else:
-                powers[s] = [n]
+        s_str = str(s)
+
+        if s_str in powers:
+            powers[s_str].append((n, exponent, s))
+        else:
+            powers[s_str] = [(n, exponent, s)]
 
     for root, occurrences in powers.iteritems():
         # If a root has multiple occurences, their exponents can be added to
         # create a single power with that root
         if len(occurrences) > 1:
-            for pair in combinations(occurrences, 2):
-                p.append(P(node, add_exponents, pair))
+            for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
+                p.append(P(node, add_exponents, (n0, n1, a0, e1, e2)))
 
     return p
 
 
+def add_exponents(root, args):
+    """
+    a^p * a^q  ->  a^(p + q)
+    """
+    n0, n1, a, p, q = args
+    scope = root.get_scope()
+
+    # Replace the left node with the new expression
+    scope[scope.index(n0)] = a ** (p + q)
+
+    # Remove the right node
+    scope.remove(n1)
+
+    return nary_node('*', scope)
+
+
+MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
+        ' will reduce to {1[0]}^({1[1]} + {2[1]}).')
+
+
 def match_subtract_exponents(node):
     """
     a^p / a^q  ->  a^(p - q)
@@ -120,26 +150,32 @@ def match_exponent_to_root(node):
     return []
 
 
-def add_exponents(root, args):
+def match_extend_exponent(node):
     """
-    a^p * a^q  ->  a^(p + q)
+    (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1)  # n > 1
     """
-    n0, n1 = args
-    a, p = n0
-    q = n1[1]
-    scope = root.get_scope()
+    assert node.is_op(OP_POW)
 
-    # Replace the left node with the new expression
-    scope[scope.index(n0)] = a ** (p + q)
+    left, right = node
 
-    # Remove the right node
-    scope.remove(n1)
+    if right.is_numeric():
+        for n in node.get_scope():
+            if n.is_op(OP_ADD):
+                return [P(node, extend_exponent, (left, right))]
 
-    return nary_node('*', scope)
+    return []
 
 
-MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
-        ' will reduce to {1[0]}^({1[1]} + {2[1]}).')
+def extend_exponent(root, args):
+    """
+    (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1)  # n > 1
+    """
+    left, right = args
+
+    if right.value > 2:
+        return left * left ** L(right.value - 1)
+
+    return left * left
 
 
 def subtract_exponents(root, args):

+ 2 - 1
tests/rulestestcase.py

@@ -43,7 +43,8 @@ class RulesTestCase(unittest.TestCase):
     def assertRewrite(self, rewrite_chain):
         try:
             for i, exp in enumerate(rewrite_chain[:-1]):
-                self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
+                self.assertMultiLineEqual(str(rewrite(exp)),
+                                          str(rewrite_chain[i+1]))
         except AssertionError:  # pragma: nocover
             print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
             print 'rewrite chain:', rewrite_chain

+ 54 - 1
tests/test_leiden_oefenopgave.py

@@ -2,11 +2,14 @@ from tests.rulestestcase import RulesTestCase as TestCase, rewrite
 
 
 class TestLeidenOefenopgave(TestCase):
-    def test_1(self):
+    def test_1_1(self):
         for chain in [['-5(x2 - 3x + 6)', '-5(x ^ 2 - 3x) - 5 * 6',
                        '-5 * x ^ 2 - 5 * -3x - 5 * 6',
                        '-5 * x ^ 2 - -15x - 5 * 6',
                        # FIXME: '-5 * x ^ 2 - 5 * -3x - 30',
+                       # FIXME: '-5 * x ^ 2 - -15x - 5 * 6',
+                       # FIXME: '-5 * x ^ 2 + 15x - 5 * 6',
+                       # FIXME: '-5 * x ^ 2 + 15x - 30',
                        ], #'-30 + 15 * x - 5 * x ^ 2'],
                      ]:
             self.assertRewrite(chain)
@@ -23,6 +26,56 @@ class TestLeidenOefenopgave(TestCase):
                 ]:
             self.assertEqual(str(rewrite(exp)), solution)
 
+    def test_1_2(self):
+        for chain in [['(x+1)^3', '(x + 1)(x + 1) ^ 2',
+                '(x + 1)(x + 1)(x + 1)',
+                '(xx + x * 1 + 1x + 1 * 1)(x + 1)',
+                '(x ^ (1 + 1) + x * 1 + 1x + 1 * 1)(x + 1)',
+                '(x ^ 2 + x * 1 + 1x + 1 * 1)(x + 1)',
+                '(x ^ 2 + (1 + 1)x + 1 * 1)(x + 1)',
+                '(x ^ 2 + 2x + 1 * 1)(x + 1)',
+                '(x ^ 2 + 2x + 1)(x + 1)',
+                '(x ^ 2 + 2x)x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x * x ^ 2 + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x ^ (1 + 2) + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x ^ 3 + x * 2x + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x ^ 3 + x ^ (1 + 1) * 2 + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x ^ 3 + x ^ 2 * 2 + (x ^ 2 + 2x) * 1 + 1x + 1 * 1',
+                'x ^ 3 + x ^ 2 * 2 + 1 * x ^ 2 + 1 * 2x + 1x + 1 * 1',
+                'x ^ 3 + (2 + 1) * x ^ 2 + 1 * 2x + 1x + 1 * 1',
+                'x ^ 3 + 3 * x ^ 2 + 1 * 2x + 1x + 1 * 1',
+                'x ^ 3 + 3 * x ^ 2 + 2x + 1x + 1 * 1',
+                'x ^ 3 + 3 * x ^ 2 + (2 + 1)x + 1 * 1',
+                'x ^ 3 + 3 * x ^ 2 + 3x + 1 * 1',
+                'x ^ 3 + 3 * x ^ 2 + 3x + 1',
+                ]
+            ]:
+            self.assertRewrite(chain)
+
+    def test_1_3(self):
+        # (x+1)^2 -> x^2 + 2x + 1
+        for chain in [['(x+1)^2', '(x + 1)(x + 1)',
+                       'xx + x * 1 + 1x + 1 * 1',
+                       'x ^ (1 + 1) + x * 1 + 1x + 1 * 1',
+                       'x ^ 2 + x * 1 + 1x + 1 * 1',
+                       'x ^ 2 + (1 + 1)x + 1 * 1',
+                       'x ^ 2 + 2x + 1 * 1',
+                       'x ^ 2 + 2x + 1'],
+                     ]:
+            self.assertRewrite(chain)
+
+    def test_1_4(self):
+        # (x-1)^2 -> x^2 - 2x + 1
+        for chain in [['(x-1)^2', '(x - 1)(x - 1)',
+                       'xx + x * -1 - 1x - 1 * -1',
+                       'x ^ (1 + 1) + x * -1 - 1x - 1 * -1',
+                       'x ^ 2 + x * -1 - 1x - 1 * -1',
+                       # FIXME: 'x ^ 2 + (-1 - 1)x - 1 * -1',
+                       # FIXME: 'x ^ 2 - 2x - 1 * -1',
+                       # FIXME: 'x ^ 2 - 2x + 1',
+                     ]]:
+            self.assertRewrite(chain)
+
     def test_2(self):
         pass
 

+ 7 - 7
tests/test_rules_powers.py

@@ -17,7 +17,7 @@ class TestRulesPowers(RulesTestCase):
 
         possibilities = match_add_exponents(root)
         self.assertEqualPos(possibilities,
-                [P(root, add_exponents, (n0, n1))])
+                [P(root, add_exponents, (n0, n1, a, p, q))])
 
     def test_match_add_exponents_ternary(self):
         a, p, q, r = tree('a,p,q,r')
@@ -25,9 +25,9 @@ class TestRulesPowers(RulesTestCase):
 
         possibilities = match_add_exponents(root)
         self.assertEqualPos(possibilities,
-                [P(root, add_exponents, (n0, n1)),
-                 P(root, add_exponents, (n0, n2)),
-                 P(root, add_exponents, (n1, n2))])
+                [P(root, add_exponents, (n0, n1, a, p, q)),
+                 P(root, add_exponents, (n0, n2, a, p, r)),
+                 P(root, add_exponents, (n1, n2, a, q, r))])
 
     def test_match_add_exponents_multiple_identifiers(self):
         a, b, p, q = tree('a,b,p,q')
@@ -35,8 +35,8 @@ class TestRulesPowers(RulesTestCase):
 
         possibilities = match_add_exponents(root)
         self.assertEqualPos(possibilities,
-                [P(root, add_exponents, (a0, a1)),
-                 P(root, add_exponents, (b0, b1))])
+                [P(root, add_exponents, (a0, a1, a, p, q)),
+                 P(root, add_exponents, (b0, b1, b, p, q))])
 
     def test_match_subtract_exponents_powers(self):
         a, p, q = tree('a,p,q')
@@ -103,7 +103,7 @@ class TestRulesPowers(RulesTestCase):
         a, p, q = tree('a,p,q')
         n0, n1 = root = a ** p * a ** q
 
-        self.assertEqualNodes(add_exponents(root, (n0, n1)), a ** (p + q))
+        self.assertEqualNodes(add_exponents(root, (n0, n1, a, p, q)), a ** (p + q))
 
     def test_subtract_exponents(self):
         a, p, q = tree('a,p,q')