Răsfoiți Sursa

Fixed merge conflict.

Sander Mathijs van Veen 14 ani în urmă
părinte
comite
e86e7b9c01

+ 14 - 0
src/node.py

@@ -1,6 +1,7 @@
 # vim: set fileencoding=utf-8 :
 import os.path
 import sys
+import copy
 
 sys.path.insert(0, os.path.realpath('external'))
 
@@ -59,6 +60,9 @@ def to_expression(obj):
 
 
 class ExpressionBase(object):
+    def clone(self):
+        return copy.deepcopy(self)
+
     def __lt__(self, other):
         """
         Comparison between this expression{node,leaf} and another
@@ -161,6 +165,9 @@ class ExpressionNode(Node, ExpressionBase):
 
         return False
 
+    def substitute(self, old_child, new_child):
+        self.nodes[self.nodes.index(old_child)] = new_child
+
     def graph(self):  # pragma: nocover
         return generate_graph(self)
 
@@ -182,6 +189,9 @@ class ExpressionNode(Node, ExpressionBase):
         >>> n2 = N('*', N('^', r, e), c)
         >>> n2.extract_polynome()
         (c, r, e)
+        >>> n3 = N('-', r)
+        >>> n3.extract_polynome()
+        (1, -r, 1)
         """
         # TODO: change "get_polynome" -> "extract_polynome".
         # TODO: change retval of c * r ^ e to (c, r, e).
@@ -191,6 +201,10 @@ class ExpressionNode(Node, ExpressionBase):
         if self.is_power():
             return (ExpressionLeaf(1), self[0], self[1])
 
+        # rule: -r -> (1, r, 1)
+        if self.is_op(OP_NEG):
+            return (ExpressionLeaf(1), -self[0], ExpressionLeaf(1))
+
         if self.op != OP_MUL:
             return
 

+ 44 - 8
src/parser.py

@@ -79,6 +79,9 @@ class Parser(BisonParser):
         self.read_buffer = ''
         self.read_queue = Queue.Queue()
 
+        self.subtree_map = {}
+        self.root_node = None
+
     # Override default read method with a version that prompts for input.
     def read(self, nbytes):
         if self.file == sys.stdin and self.file.closed:
@@ -181,11 +184,43 @@ class Parser(BisonParser):
                 or retval.type != TYPE_OPERATOR or retval.op not in RULES:
             return retval
 
+        # Update the subtree map to let the subtree point to its parent node.
+        parent_nodes = self.subtree_map.keys()
+
+        for child in retval:
+            if child in parent_nodes:
+                self.subtree_map[child] = retval
+
         for handler in RULES[retval.op]:
-            self.possibilities.extend(handler(retval))
+            possibilities = handler(retval)
+
+            # Record the subtree root node in order to avoid tree traversal.
+            # At this moment, the node is the root node since the expression is
+            # parser using the left-innermost parsing strategy.
+            for p in possibilities:
+                self.subtree_map[p.root] = None
+
+            self.possibilities.extend(possibilities)
 
         return retval
 
+    def display_hint(self):
+        print pick_suggestion(self.last_possibilities)
+
+    def display_possibilities(self):
+        print '\n'.join(map(str, self.last_possibilities))
+
+    def rewrite(self):
+        suggestion = pick_suggestion(self.last_possibilities)
+
+        if not suggestion:
+            return self.root_node
+
+        expression = apply_suggestion(self.root_node, self.subtree_map,
+                                    suggestion)
+        self.read_queue.put_nowait(str(expression))
+        return expression
+
     #def hook_run(self, filename, retval):
     #    return retval
 
@@ -226,22 +261,23 @@ class Parser(BisonParser):
              | REWRITE NEWLINE
              | RAISE NEWLINE
         """
-        if option in [1, 2]:  # rule: EXP NEWLINE | DEBUG NEWLINE
+        if option == 1:  # rule: EXP NEWLINE
+            self.root_node = values[0]
+            return values[0]
+
+        if option == 2:  # rule: DEBUG NEWLINE
             return values[0]
 
         if option == 3:  # rule: HINT NEWLINE
-            print pick_suggestion(self.last_possibilities)
+            self.display_hint()
             return
 
         if option == 4:  # rule: POSSIBILITIES NEWLINE
-            print '\n'.join(map(str, self.last_possibilities))
+            self.display_possibilities()
             return
 
         if option == 5:  # rule: REWRITE NEWLINE
-            suggestion = pick_suggestion(self.last_possibilities)
-            expression = apply_suggestion(suggestion)
-            self.read_queue.put_nowait(str(expression))
-            return expression
+            return self.rewrite()
 
         if option == 6:
             raise RuntimeError('on_line: exception raised')

+ 32 - 6
src/possibilities.py

@@ -15,8 +15,7 @@ class Possibility(object):
         if self.handler in MESSAGES:
             return MESSAGES[self.handler].format(self.root, *self.args)
 
-        return '<Possibility root="%s" handler=%s args=%s>' \
-                % (self.root, self.handler.func_name, self.args)
+        return self.__repr__()
 
     def __repr__(self):
         return '<Possibility root="%s" handler=%s args=%s>' \
@@ -32,11 +31,11 @@ def filter_duplicates(possibilities):
     """
     Filter duplicated possibilities. Duplicated possibilities occur in n-ary
     nodes, the root-level node and a lower-level node will both recognize a
-    rewrite possibility within their sscope, whereas only the root-level one
+    rewrite possibility within their scope, whereas only the root-level one
     matters.
 
     Example: 1 + 2 + 3
-    The addition of 1 and 2 is recognized bij n-ary additions "1 + 2" and
+    The addition of 1 and 2 is recognized by n-ary additions "1 + 2" and
     "1 + 2 + 3". The "1 + 2" addition should be removed by this function.
     """
     features = []
@@ -53,10 +52,37 @@ def filter_duplicates(possibilities):
 
 
 def pick_suggestion(possibilities):
+    if not possibilities:
+        return
+
     # TODO: pick the best suggestion.
     suggestion = 0
     return possibilities[suggestion]
 
 
-def apply_suggestion(suggestion):
-    return suggestion.handler(suggestion.root, suggestion.args)
+def apply_suggestion(root, subtree_map, suggestion):
+    # clone the root node before modifying. After deep copying the root node,
+    # the subtree_map cannot be used since the hash() of each node in the deep
+    # copied root node has changed.
+    #root_clone = root.clone()
+
+    subtree = suggestion.handler(suggestion.root, suggestion.args)
+
+    if suggestion.root in subtree_map:
+        parent_node = subtree_map[suggestion.root]
+    else:
+        parent_node = None
+
+    # There is either a parent node or the subtree is the root node.
+    # FIXME: FAIL: test_diagnostic_test_application in tests/test_b1_ch08.py
+    #try:
+    #    assert bool(parent_node) != (subtree == root)
+    #except:
+    #    print 'parent_node: %s' % (str(parent_node))
+    #    print 'subtree: %s == %s' % (str(subtree), str(root))
+    #    raise
+
+    if parent_node:
+        parent_node.substitute(suggestion.root, subtree)
+        return root
+    return subtree

+ 0 - 2
src/rules/factors.py

@@ -13,8 +13,6 @@ def match_expand(node):
     """
     assert node.is_op(OP_MUL)
 
-    scope = node.get_scope()
-
     p = []
     leaves = []
     additions = []

+ 32 - 10
src/rules/fractions.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from .utils import nary_node, least_common_multiple
-from ..node import ExpressionLeaf as L, OP_DIV, OP_ADD, OP_MUL
+from ..node import ExpressionLeaf as L, OP_DIV, OP_ADD, OP_MUL, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -63,15 +63,23 @@ def match_add_constant_fractions(node):
     1 / 2 + 3 / 4  ->  2 / 4 + 3 / 4  # Equalize denominators
     2 / 4 + 3 / 4  ->  5 / 4          # Equal denominators, so nominators can
                                       # be added
+    2 / 2 - 3 / 4  ->  4 / 4 - 3 / 4  # Equalize denominators
+    2 / 4 - 3 / 4  ->  -1 / 4         # Equal denominators, so nominators can
+                                      # be subtracted
     """
     assert node.is_op(OP_ADD)
 
     p = []
-    fractions = filter(lambda n: n.is_op(OP_DIV), node.get_scope())
+
+    def is_division(node):
+        return node.is_op(OP_DIV) or \
+                (node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
+
+    fractions = filter(is_division, node.get_scope())
 
     for a, b in combinations(fractions, 2):
-        na, da = a
-        nb, db = b
+        na, da = a if a.is_op(OP_DIV) else a[0]
+        nb, db = b if b.is_op(OP_DIV) else b[0]
 
         if da == db:
             # Equal denominators, add nominators to create a single fraction
@@ -96,28 +104,40 @@ def equalize_denominators(root, args):
     scope = root.get_scope()
 
     for fraction in args[:2]:
-        n, d = fraction
+        n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction
         mult = denom / d.value
 
         if mult != 1:
             n = L(n.value * mult) if n.is_numeric() else L(mult) * n
-            scope[scope.index(fraction)] = n / L(d.value * mult)
+
+            if fraction.is_op(OP_NEG):
+                scope[scope.index(fraction)] = -(n / L(d.value * mult))
+            else:
+                scope[scope.index(fraction)] = n / L(d.value * mult)
 
     return nary_node('+', scope)
 
 
 def add_nominators(root, args):
     """
-    a / b + c / b  ->  (a + c) / b
+    a / b + c / b     ->  (a + c) / b
+    a / b + (-c / b)  ->  (a + (-c)) / b
     """
+    # TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"?
     ab, cb = args
     a, b = ab
-    c = cb[0]
+
+    if cb[0].is_op(OP_NEG):
+        c = cb[0][0]
+        substitution = (a + (-c)) / b
+    else:
+        c = cb[0]
+        substitution = (a + c) / b
 
     scope = root.get_scope()
 
     # Replace the left node with the new expression
-    scope[scope.index(ab)] = (a + c) / b
+    scope[scope.index(ab)] = substitution
 
     # Remove the right node
     scope.remove(cb)
@@ -127,8 +147,10 @@ def add_nominators(root, args):
 
 def match_expand_and_add_fractions(node):
     """
-    a * b / c + d * b / c  ->  (a + d) * (b / c)
+    a * b / c + d * b / c      ->  (a + d) * (b / c)
+    a * b / c + (- d * b / c)  ->  (a + (-d)) * (b / c)
     """
+    # TODO: is 'add' Appropriate when rewriting to "(a + (-d)) / * (b / c)"?
     assert node.is_op(OP_MUL)
 
     p = []

+ 1 - 0
src/rules/groups.py

@@ -18,6 +18,7 @@ def match_combine_groups(node):
     ab + 2ab  ->  3ab
     ab + ba   ->  2ab
     """
+    # TODO: handle OP_NEG nodes
     assert node.is_op(OP_ADD)
 
     p = []

+ 10 - 4
src/rules/numerics.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from .utils import nary_node
-from ..node import ExpressionLeaf as Leaf, OP_DIV, OP_MUL
+from ..node import ExpressionLeaf as Leaf, OP_DIV, OP_MUL, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -11,10 +11,16 @@ def add_numerics(root, args):
     Combine two constants to a single constant in an n-ary addition.
 
     Example:
-    2 + 3  ->  5
+    2 + 3    ->  5
+    2 + -3   ->  -1
+    -2 + 3   ->  1
+    -2 + -3  ->  -5
     """
     n0, n1, c0, c1 = args
 
+    c0 = (-c0[0].value) if c0.is_op(OP_NEG) else c0.value
+    c1 = (-c1[0].value) if c1.is_op(OP_NEG) else c1.value
+
     scope = root.get_scope()
 
     # Replace the left node with the new expression
@@ -26,8 +32,8 @@ def add_numerics(root, args):
     return nary_node('+', scope)
 
 
-MESSAGES[add_numerics] = _('Combine the constants {3} and {4}, which'
-        ' will reduce to {3} + {4}.')
+MESSAGES[add_numerics] = _('Combine the constants {1} and {2}, which'
+        ' will reduce to {1} + {2}.')
 
 
 #def match_subtract_numerics(node):

+ 11 - 4
src/rules/poly.py

@@ -1,11 +1,15 @@
 from itertools import combinations
 
-from ..node import OP_ADD
+from ..node import OP_ADD, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from .utils import nary_node
 from .numerics import add_numerics
 
 
+def is_numeric_or_negated_numeric(n):
+    return n.is_numeric() or (n.is_op(OP_NEG) and n[0].is_numeric())
+
+
 def match_combine_polynomes(node, verbose=False):
     """
     n + exp + m -> exp + (n + m)
@@ -49,9 +53,12 @@ def match_combine_polynomes(node, verbose=False):
             # roots, or: same root and exponent -> combine coefficients.
             # TODO: Addition with zero, e.g. a + 0 -> a
             if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \
-                    and r0.is_numeric() and r1.is_numeric():
-                # 2 + 3 -> 5
-                p.append(P(node, add_numerics, (n0, n1, r0.value, r1.value)))
+                    and all(map(is_numeric_or_negated_numeric, [r0, r1])):
+                # 2 + 3    ->  5
+                # 2 + -3   ->  -1
+                # -2 + 3   ->  1
+                # -2 + -3  ->  -5
+                p.append(P(node, add_numerics, (n0, n1, r0, r1)))
             elif c0.is_numeric() and c1.is_numeric() and r0 == r1 and e0 == e1:
                 # 2a + 2a -> 4a
                 # a + 2a -> 3a

+ 2 - 0
tests/parser.py

@@ -120,6 +120,8 @@ def apply_expressions(base_class, expressions, fail=True, silent=False,
 
             if fail:
                 raise
+
+
 def graph(parser, *exp, **kwargs):
     return generate_graph(ParserWrapper(parser, **kwargs).run(exp))
 

+ 15 - 12
tests/test_leiden_oefenopgave.py

@@ -4,8 +4,8 @@ from src.parser import Parser
 from tests.parser import ParserWrapper
 
 
-def reduce(exp, **kwargs):
-    return ParserWrapper(Parser, **kwargs).run([exp]).reduce()
+def rewrite(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
 
 
 class TestLeidenOefenopgave(TestCase):
@@ -19,7 +19,7 @@ class TestLeidenOefenopgave(TestCase):
                 ('-2(6x-4)^2*x',         '-72 * x^3 + 96 * x ^ 2 + 32 * x'),
                 ('(4x + 5) * -(5 - 4x)', '16x^2 - 25'),
                 ]:
-            self.assertEqual(str(reduce(exp)), solution)
+            self.assertEqual(str(rewrite(exp)), solution)
 
     def test_2(self):
         pass
@@ -28,14 +28,17 @@ class TestLeidenOefenopgave(TestCase):
         pass
 
     def test_4(self):
-        return
         for exp, solution in [
-                ('2/15 + 1/4',      '23/60'),
-                ('2/7 - 4/11',      '-6/77'),
-                ('(7/3) * (3/5)',   '7/5'),
-                ('(3/4) / (5/6)',   '9/10'),
-                ('1/4 * 1/x',       '1/(4x)'),
-                ('(3/x^2) / (x/7)', '21/x^3'),
-                ('1/x + 2/(x+1)',   '(3x + 1) / (x * (x + 1))'),
+                ('2/15 + 1/4',      '8 / 60 + 15 / 60'),
+                ('8/60 + 15/60',    '(8 + 15) / 60'),
+                ('(8 + 15) / 60',   '23 / 60'),
+                ('2/7 - 4/11',      '22 / 77 + -28 / 77'),
+                ('22/77 + -28/77',  '(22 + -28) / 77'),
+                ('(22 + -28)/77',    '-6 / 77'),
+                # FIXME: ('(7/3) * (3/5)',   '7 / 5'),
+                # FIXME: ('(3/4) / (5/6)',   '9 / 10'),
+                # FIXME: ('1/4 * 1/x',       '1 / (4x)'),
+                # FIXME: ('(3/x^2) / (x/7)', '21 / x^3'),
+                # FIXME: ('1/x + 2/(x+1)',   '(3x + 1) / (x * (x + 1))'),
                 ]:
-            self.assertEqual(str(reduce(exp)), solution)
+            self.assertEqual(str(rewrite(exp)), solution)

+ 30 - 0
tests/test_rewrite.py

@@ -0,0 +1,30 @@
+from unittest import TestCase
+
+from src.parser import Parser
+from tests.parser import ParserWrapper
+
+
+def rewrite(exp, **kwargs):
+    return ParserWrapper(Parser, **kwargs).run([exp, '@'])
+
+
+class TestRewrite(TestCase):
+
+    def assertRewrite(self, rewrite_chain):
+        try:
+            for i, exp in enumerate(rewrite_chain[:-1]):
+                self.assertEqual(str(rewrite(exp)), str(rewrite_chain[i+1]))
+        except AssertionError:  # pragma: nocover
+            print 'rewrite failed:', exp, '->', rewrite_chain[i+1]
+            print 'rewrite chain:', rewrite_chain
+            raise
+
+    def test_addition_rewrite(self):
+        self.assertRewrite(['2 + 3 + 4', '5 + 4', '9'])
+
+    def test_addition_identifiers_rewrite(self):
+        self.assertRewrite(['2 + 3a + 4', '6 + 3a'])
+
+    def test_division_rewrite(self):
+        self.assertRewrite(['2/7 - 4/11', '22 / 77 + -28 / 77',
+                            '(22 + -28) / 77', '-6 / 77'])

+ 28 - 1
tests/test_rules_fractions.py

@@ -68,6 +68,19 @@ class TestRulesFractions(RulesTestCase):
         self.assertEqualPos(possibilities,
                 [P(root, add_nominators, (n1, n3))])
 
+    def test_add_constant_fractions_with_negation(self):
+        a, b, c, l1, l2, l3, l4 = tree('a,b,c,1,2,3,4')
+
+        (((n0, n1), n2), n3), n4 = root = a + l2 / l2 + b + (-l3 / l4) + c
+        possibilities = match_add_constant_fractions(root)
+        self.assertEqualPos(possibilities,
+                [P(root, equalize_denominators, (n1, n3, 4))])
+
+        (((n0, n1), n2), n3), n4 = root = a + l2 / l4 + b + (-l3 / l4) + c
+        possibilities = match_add_constant_fractions(root)
+        self.assertEqualPos(possibilities,
+                [P(root, add_nominators, (n1, n3))])
+
     def test_equalize_denominators(self):
         a, b, l1, l2, l3, l4 = tree('a,b,1,2,3,4')
 
@@ -79,8 +92,22 @@ class TestRulesFractions(RulesTestCase):
         self.assertEqualNodes(equalize_denominators(root, (n0, n1, 4)),
                               (l2 * a) / l4 + b / l4)
 
+        #2 / 2 - 3 / 4  ->  4 / 4 - 3 / 4  # Equalize denominators
+        n0, n1 = root = l1 / l2 + (-l3 / l4)
+        self.assertEqualNodes(equalize_denominators(root, (n0, n1, 4)),
+                              l2 / l4 + (-l3 / l4))
+
+        #2 / 2 - 3 / 4  ->  4 / 4 - 3 / 4  # Equalize denominators
+        n0, n1 = root = a / l2 + (-b / l4)
+        self.assertEqualNodes(equalize_denominators(root, (n0, n1, 4)),
+                              (l2 * a) / l4 + (-b / l4))
+
     def test_add_nominators(self):
         a, b, c = tree('a,b,c')
         n0, n1 = root = a / b + c / b
-
         self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + c) / b)
+
+        #2 / 4 + 3 / -4  ->  2 / 4 + -3 / 4
+        #2 / 4 - 3 / 4  ->  -1 / 4  # Equal denominators, so nominators can
+        n0, n1 = root = a / b + (-c / b)
+        self.assertEqualNodes(add_nominators(root, (n0, n1)), (a + (-c)) / b)

+ 10 - 2
tests/test_rules_numerics.py

@@ -10,8 +10,16 @@ class TestRulesNumerics(RulesTestCase):
     def test_add_numerics(self):
         l0, a, l1 = tree('1,a,2')
 
-        self.assertEqual(add_numerics(l0 + l1, (l0, l1, 1, 2)), 3)
-        self.assertEqual(add_numerics(l0 + a + l1, (l0, l1, 1, 2)), L(3) + a)
+        self.assertEqual(add_numerics(l0 + l1, (l0, l1, L(1), L(2))), 3)
+        self.assertEqual(add_numerics(l0 + a + l1, (l0, l1, L(1), L(2))),
+                         L(3) + a)
+
+    def test_add_numerics_negations(self):
+        l0, a, l1 = tree('1,a,2')
+
+        self.assertEqual(add_numerics(l0 + -l1, (l0, -l1, L(1), -L(2))), -1)
+        self.assertEqual(add_numerics(l0 + a + -l1, (l0, -l1, L(1), -L(2))),
+                         L(-1) + a)
 
     def test_match_divide_numerics(self):
         a, b, i2, i3, i6, f1, f2, f3 = tree('a,b,2,3,6,1.0,2.0,3.0')

+ 0 - 2
tests/test_rules_poly.py

@@ -36,14 +36,12 @@ class TestRulesPoly(RulesTestCase):
         self.assertEqualPos(possibilities,
                 [P(root, combine_polynomes, (a1, a2, 2, 1, 'a', 3))])
 
-
     def test_identifiers_coeff_exponent_both(self):
         a1, a2 = root = tree('2a3+2a3')
         possibilities = match_combine_polynomes(root)
         self.assertEqualPos(possibilities,
                 [P(root, combine_polynomes, (a1, a2, 2, 2, 'a', 3))])
 
-
     def test_basic_subexpressions(self):
         a_b, c, d = tree('a+b,c,d')
         left, right = root = tree('(a+b)^d + (a+b)^d')