Prechádzať zdrojové kódy

Merge branch 'negated' of kompiler.org:trs

Taddeus Kroes 14 rokov pred
rodič
commit
d419cb36e1

+ 1 - 1
external/graph_drawing

@@ -1 +1 @@
-Subproject commit 11940973bdfef9432438b054c65b28af2eb97d0c
+Subproject commit 84ad376b81ac72e163bacd7b538df16cac9be153

+ 33 - 33
src/node.py

@@ -43,7 +43,7 @@ TYPE_MAP = {
 OP_MAP = {
         '+': OP_ADD,
         # Either substraction or negation. Skip the operator sign in 'x' (= 2).
-        '-': lambda x: OP_SUB if len(x) > 2 else OP_NEG,
+        '-': OP_SUB,
         '*': OP_MUL,
         '/': OP_DIV,
         '^': OP_POW,
@@ -60,6 +60,10 @@ def to_expression(obj):
 
 
 class ExpressionBase(object):
+
+    def __init__(self, *args, **kwargs):
+        self.negated = 0
+
     def clone(self):
         return copy.deepcopy(self)
 
@@ -86,16 +90,11 @@ class ExpressionBase(object):
         if self.is_leaf:
             if other.is_leaf:
                 # Both are leafs, string compare the value.
-                return str(self.value) < str(other.value)
-            # Self is a leaf, thus has less value than an expression node.
-            return True
+                self_value = '-' * (self.negated & 1) + str(self.value)
+                other_value = '-' * (other.negated & 1) + str(other.value)
+
+                return self_value < other_value
 
-        if self.is_op(OP_NEG) and self[0].is_leaf:
-            if other.is_leaf:
-                # Both are leafs, string compare the value.
-                return ('-' + str(self.value)) < str(other.value)
-            if other.is_op(OP_NEG) and other[0].is_leaf:
-                return ('-' + str(self.value)) < ('-' + str(other.value))
             # Self is a leaf, thus has less value than an expression node.
             return True
 
@@ -113,24 +112,6 @@ class ExpressionBase(object):
     def is_op(self, op):
         return not self.is_leaf and self.op == op
 
-    def is_op_or_negated(self, op):
-        if self.is_leaf:
-            return False
-
-        if self.op == OP_NEG:
-            return self[0].is_op(op)
-
-        return self.op == op
-
-    def is_leaf_or_negated(self):
-        if  self.is_leaf:
-            return True
-
-        if self.is_op(OP_NEG):
-            return self[0].is_leaf
-
-        return False
-
     def is_power(self):
         return not self.is_leaf and self.op == OP_POW
 
@@ -164,8 +145,13 @@ class ExpressionBase(object):
     def __pow__(self, other):
         return ExpressionNode('^', self, to_expression(other))
 
-    def __neg__(self):
-        return ExpressionNode('-', self)
+    def reduce_negation(self, n=1):
+        """Remove n negation flags from the node."""
+        return self.negate(-n)
+
+    def negate(self, n=1):
+        """Negate the node n times."""
+        return negate(self, self.negated + n)
 
 
 class ExpressionNode(Node, ExpressionBase):
@@ -226,8 +212,10 @@ class ExpressionNode(Node, ExpressionBase):
             return (ExpressionLeaf(1), self[0], self[1])
 
         # rule: -r -> (1, r, 1)
-        if self.is_op(OP_NEG):
-            return (ExpressionLeaf(1), -self[0], ExpressionLeaf(1))
+        # rule: --r -> (1, r, 1)
+        # rule: ---r -> (1, r, 1)
+        if self.negated:
+            return (ExpressionLeaf(1), self, ExpressionLeaf(1))
 
         if self.op != OP_MUL:
             return
@@ -309,7 +297,6 @@ class ExpressionNode(Node, ExpressionBase):
 class ExpressionLeaf(Leaf, ExpressionBase):
     def __init__(self, *args, **kwargs):
         super(ExpressionLeaf, self).__init__(*args, **kwargs)
-
         self.type = TYPE_MAP[type(args[0])]
 
     def __eq__(self, other):
@@ -339,6 +326,11 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         # rule: 1 * r ^ 1 -> (1, r, 1)
         return (ExpressionLeaf(1), self, ExpressionLeaf(1))
 
+    def actual_value(self):
+        assert self.is_numeric()
+
+        return (1 - 2 * (self.negated & 1)) * self.value
+
 
 class Scope(object):
 
@@ -409,3 +401,11 @@ def get_scope(node):
             scope.append(child)
 
     return scope
+
+
+def negate(node, n=1):
+    """Negate the given node n times."""
+    node = node.clone()
+    node.negated = n
+
+    return node

+ 19 - 13
src/parser.py

@@ -16,7 +16,7 @@ sys.path.insert(1, EXTERNAL_MODS)
 from pybison import BisonParser, BisonSyntaxError
 from graph_drawing.graph import generate_graph
 
-from node import TYPE_OPERATOR, OP_COMMA
+from node import TYPE_OPERATOR, OP_COMMA, OP_NEG
 from rules import RULES
 from possibilities import filter_duplicates, pick_suggestion, apply_suggestion
 
@@ -180,11 +180,13 @@ class Parser(BisonParser):
         return data
 
     def hook_handler(self, target, option, names, values, retval):
-        if target in ['exp', 'line', 'input'] or not retval \
-                or retval.type != TYPE_OPERATOR:
+        if target in ['exp', 'line', 'input'] or not retval:
             return retval
 
-        if self.subtree_map:
+        if not retval.negated and retval.type != TYPE_OPERATOR:
+            return retval
+
+        if self.subtree_map and retval.type == TYPE_OPERATOR:
             # Update the subtree map to let the subtree point to its parent
             # node.
             parent_nodes = self.subtree_map.keys()
@@ -193,10 +195,15 @@ class Parser(BisonParser):
                 if child in parent_nodes:
                     self.subtree_map[child] = retval
 
-        if retval.op not in RULES:
-            return retval
+        if retval.type == TYPE_OPERATOR and retval.op in RULES:
+            handlers = RULES[retval.op]
+        else:
+            handlers = []
+
+        if retval.negated:
+            handlers += RULES[OP_NEG]
 
-        for handler in RULES[retval.op]:
+        for handler in handlers:
             possibilities = handler(retval)
 
             # Record the subtree root node in order to avoid tree traversal.
@@ -343,7 +350,9 @@ class Parser(BisonParser):
         """
 
         if option == 0:  # rule: NEG exp
-            return Node('-', values[1])
+            node = values[1]
+            node.negated += 1
+            return node
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
                                % (option, target))  # pragma: nocover
@@ -361,11 +370,8 @@ class Parser(BisonParser):
             return Node(values[1], values[0], values[2])
 
         if option == 4:  # rule: exp MINUS exp
-            # It is necessary to call the hook_handler here explicitly, since
-            # the minus operator is internally represented as two nodes (unary
-            # negation and binary plus).
-            node = Node('-', values[2])
-            node = self.hook_handler(target, option, names, values, node)
+            node = values[2]
+            node.negated += 1
             return Node('+', values[0], node)
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'

+ 2 - 2
src/rules/factors.py

@@ -1,6 +1,6 @@
 from itertools import product, combinations
 
-from ..node import Scope, OP_ADD, OP_MUL, OP_NEG
+from ..node import Scope, OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -18,7 +18,7 @@ def match_expand(node):
     additions = []
 
     for n in Scope(node):
-        if n.is_leaf or n.is_op(OP_NEG) and n[0].is_leaf:
+        if n.is_leaf:
             leaves.append(n)
         elif n.op == OP_ADD:
             additions.append(n)

+ 9 - 33
src/rules/fractions.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from .utils import least_common_multiple
-from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL, OP_NEG
+from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -80,22 +80,11 @@ def match_add_constant_fractions(node):
 
     p = []
 
-    def is_division(node):
-        return node.is_op(OP_DIV) or \
-                (node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
-
-    fractions = filter(is_division, Scope(node))
+    fractions = filter(lambda node: node.is_op(OP_DIV), Scope(node))
 
     for a, b in combinations(fractions, 2):
-        if a.is_op(OP_NEG):
-            na, da = a[0]
-        else:
-            na, da = a
-
-        if b.is_op(OP_NEG):
-            nb, db = b[0]
-        else:
-            nb, db = b
+        na, da = a
+        nb, db = b
 
         if da == db:
             # Equal denominators, add nominators to create a single fraction
@@ -116,20 +105,17 @@ def equalize_denominators(root, args):
     a / 2 + b / 4  ->  2a / 4 + b / 4
     """
     denom = args[2]
-
     scope = Scope(root)
 
     for fraction in args[:2]:
-        n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction
+        n, d = fraction
         mult = denom / d.value
 
         if mult != 1:
             n = L(n.value * mult) if n.is_numeric() else L(mult) * n
 
-            if fraction.is_op(OP_NEG):
-                scope.remove(fraction, -(n / L(d.value * mult)))
-            else:
-                scope.remove(fraction, n / L(d.value * mult))
+            scope.remove(fraction, negate(n / L(d.value * mult),
+                                          fraction.negated))
 
     return scope.as_nary_node()
 
@@ -147,21 +133,11 @@ def add_nominators(root, args):
     """
     # TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"?
     ab, cb = args
-
-    if ab.is_op(OP_NEG):
-        a, b = ab[0]
-    else:
-        a, b = ab
-
-    if cb.is_op(OP_NEG):
-        c = -cb[0][0]
-    else:
-        c = cb[0]
-
+    a, b = ab
     scope = Scope(root)
 
     # Replace the left node with the new expression
-    scope.remove(ab, (a + c) / b)
+    scope.remove(ab, (a + negate(cb[0], cb.negated)) / b)
 
     # Remove the right node
     scope.remove(cb)

+ 2 - 4
src/rules/groups.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
-        OP_ADD, OP_MUL, OP_NEG
+        OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -18,7 +18,6 @@ def match_combine_groups(node):
     ab + 2ab  ->  3ab
     ab + ba   ->  2ab
     """
-    # TODO: handle OP_NEG nodes
     assert node.is_op(OP_ADD)
 
     p = []
@@ -34,8 +33,7 @@ def match_combine_groups(node):
             l = len(scope)
 
             for i, sub_node in enumerate(scope):
-                if sub_node.is_numeric() or (sub_node.is_op(OP_NEG)
-                                             and sub_node[0].is_numeric()):
+                if sub_node.is_numeric():
                     others = [scope[j] for j in range(i) + range(i + 1, l)]
 
                     if len(others) == 1:

+ 27 - 38
src/rules/negation.py

@@ -1,4 +1,4 @@
-from ..node import get_scope, nary_node, OP_NEG, OP_ADD, OP_MUL, OP_DIV
+from ..node import get_scope, nary_node, OP_ADD, OP_MUL, OP_DIV
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -6,58 +6,48 @@ from ..translate import _
 def match_negate_group(node):
     """
     --a                 ->  a
-    --ab                ->  ab
-    -(-ab + c)          ->  --ab - c
+    -(a * ... * -b)     ->  ab
     -(a + b + ... + z)  ->  -a + -b + ... + -z
     """
-    assert node.is_op(OP_NEG)
+    assert node.negated
 
-    val = node[0]
-
-    if val.is_op(OP_NEG):
+    if node.negated == 2:
         # --a
         return [P(node, double_negation, (node,))]
 
-    if not val.is_leaf:
-        scope = get_scope(val)
-
-        if not any(map(lambda n: n.is_op(OP_NEG), scope)):
-            return []
+    if not node.is_leaf:
+        scope = get_scope(node)
 
-        if val.is_op(OP_MUL):
-            # --ab
-            return [P(node, negate_polynome, (node, scope))]
+        if node.is_op(OP_MUL) and any(map(lambda n: n.negated, scope)):
+            # -(-a)b
+            return [P(node, negate_group, (node, scope))]
 
-        elif val.is_op(OP_ADD):
+        if node.is_op(OP_ADD):
             # -(ab + c)   ->  -ab - c
             # -(-ab + c)  ->  ab - c
-            return [P(node, negate_group, (node, scope))]
+            return [P(node, negate_polynome, (node, scope))]
 
     return []
 
 
-def negate_polynome(root, args):
+def negate_group(root, args):
     """
-    # -a * -3c  ->  a * 3c
-    --a * 3c  ->  a * 3c
-    --ab      ->  ab
-    --abc     ->  abc
+    -(a * -3c)       ->  a * 3c
+    -(a * ... * -b)  ->  ab
     """
     node, scope = args
 
     for i, n in enumerate(scope):
-        # XXX: validate this property!
-        if n.is_op(OP_NEG):
-            scope[i] = n[0]
-            return nary_node('*', scope)
+        if n.negated:
+            scope[i] = n.reduce_negation()
 
-    raise RuntimeError('No negation node found in scope.')
+    return nary_node('*', scope).reduce_negation()
 
 
-MESSAGES[negate_polynome] = _('Apply negation to the polynome {1[0]}.')
+MESSAGES[negate_group] = _('Apply negation to the polynome {1[0]}.')
 
 
-def negate_group(root, args):
+def negate_polynome(root, args):
     """
     -(-ab + ... + c)  ->  --ab + ... + -c
     """
@@ -70,16 +60,14 @@ def negate_group(root, args):
     return nary_node('+', scope)
 
 
-MESSAGES[negate_group] = _('Apply negation to the subexpression {1[0]}.')
+MESSAGES[negate_polynome] = _('Apply negation to the subexpression {1[0]}.')
 
 
 def double_negation(root, args):
     """
     --a  ->  a
     """
-    node = args[0]
-
-    return node[0][0]
+    return negate(args[0], args[0].negated - 2)
 
 
 MESSAGES[double_negation] = _('Remove double negation in {1}.')
@@ -92,14 +80,12 @@ def match_negated_division(node):
     assert node.is_op(OP_DIV)
 
     a, b = node
-    a_neg = a.is_op(OP_NEG)
-    b_neg = b.is_op(OP_NEG)
 
-    if a_neg and b_neg:
+    if a.negated and b.negated:
         return [P(node, double_negated_division, (node,))]
-    elif a_neg:
+    elif a.negated:
         return [P(node, single_negated_division, (a[0], b))]
-    elif b_neg:
+    elif b.negated:
         return [P(node, single_negated_division, (a, b[0]))]
 
     return []
@@ -132,3 +118,6 @@ def double_negated_division(root, args):
 
 MESSAGES[double_negated_division] = \
         _('Eliminate top and bottom negation in {1}.')
+
+
+# TODO: negated multiplication: -a * -b = ab

+ 9 - 35
src/rules/numerics.py

@@ -1,6 +1,6 @@
 from itertools import combinations
 
-from ..node import ExpressionLeaf as Leaf, Scope, OP_DIV, OP_MUL, OP_NEG
+from ..node import ExpressionLeaf as Leaf, Scope, negate, OP_DIV, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -16,21 +16,10 @@ def add_numerics(root, args):
     -2 + -3  ->  -5
     """
     n0, n1, c0, c1 = args
-
-    if c0.is_op(OP_NEG):
-        c0 = -c0[0].value
-    else:
-        c0 = c0.value
-
-    if c1.is_op(OP_NEG):
-        c1 = (-c1[0].value)
-    else:
-        c1 = c1.value
-
     scope = Scope(root)
 
     # Replace the left node with the new expression
-    scope.remove(n0, Leaf(c0 + c1))
+    scope.remove(n0, Leaf(c0.actual_value() + c1.actual_value()))
 
     # Remove the right node
     scope.remove(n1)
@@ -119,20 +108,12 @@ def match_multiply_zero(node):
     assert node.is_op(OP_MUL)
 
     left, right = node
-    is_zero = lambda n: n.is_leaf and n.value == 0
-
-    if is_zero(left):
-        negated = right.is_op(OP_NEG)
-    elif is_zero(right):
-        negated = left.is_op(OP_NEG)
-    elif left.is_op(OP_NEG) and is_zero(left[0]):
-        negated = not right.is_op(OP_NEG)
-    elif right.is_op(OP_NEG) and is_zero(right[0]):
-        negated = not left.is_op(OP_NEG)
-    else:
-        return []
 
-    return [P(node, multiply_zero, (negated,))]
+    if (left.is_leaf and left.value == 0) \
+            or (right.is_leaf and right.value == 0):
+        return [P(node, multiply_zero, (left.negated + right.negated,))]
+
+    return []
 
 
 def multiply_zero(root, args):
@@ -143,12 +124,7 @@ def multiply_zero(root, args):
     0 * -a   ->  -0
     -0 * -a  ->  0
     """
-    negated = args[0]
-
-    if negated:
-        return -Leaf(0)
-    else:
-        return Leaf(0)
+    return negate(Leaf(0), args[0])
 
 
 MESSAGES[multiply_zero] = _('Multiplication with zero yields zero.')
@@ -168,9 +144,7 @@ def match_multiply_numerics(node):
 
     for n in Scope(node):
         if n.is_numeric():
-            numerics.append((n, n.value))
-        elif n.is_op(OP_NEG) and n[0].is_numeric():
-            numerics.append((n, -n[0].value))
+            numerics.append((n, n.actual_value()))
 
     for (n0, v0), (n1, v1) in combinations(numerics, 2):
         p.append(P(node, multiply_numerics, (n0, n1, v0, v1)))

+ 2 - 6
src/rules/poly.py

@@ -1,14 +1,10 @@
 from itertools import combinations
 
-from ..node import Scope, OP_ADD, OP_NEG
+from ..node import Scope, OP_ADD
 from ..possibilities import Possibility as P, MESSAGES
 from .numerics import add_numerics
 
 
-def is_numeric_or_negated_numeric(n):
-    return n.is_numeric() or (n.is_op(OP_NEG) and n[0].is_numeric())
-
-
 def match_combine_polynomes(node, verbose=False):
     """
     n + exp + m -> exp + (n + m)
@@ -52,7 +48,7 @@ def match_combine_polynomes(node, verbose=False):
             # roots, or: same root and exponent -> combine coefficients.
             # TODO: Addition with zero, e.g. a + 0 -> a
             if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \
-                    and all(map(is_numeric_or_negated_numeric, [r0, r1])):
+                    and all(map(lambda n: n.is_numeric(), [r0, r1])):
                 # 2 + 3    ->  5
                 # 2 + -3   ->  -1
                 # -2 + 3   ->  1

+ 67 - 67
src/rules/powers.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
 from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
-                   OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
+                   OP_MUL, OP_DIV, OP_POW, OP_ADD
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -91,6 +91,18 @@ def match_subtract_exponents(node):
     return []
 
 
+def subtract_exponents(root, args):
+    """
+    a^p / a^q  ->  a^(p - q)
+    """
+    a, p, q = args
+
+    return a ** (p - q)
+
+
+MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
+
+
 def match_multiply_exponents(node):
     """
     (a^p)^q  ->  a^(pq)
@@ -105,6 +117,18 @@ def match_multiply_exponents(node):
     return []
 
 
+def multiply_exponents(root, args):
+    """
+    (a^p)^q  ->  a^(pq)
+    """
+    a, p, q = args
+
+    return a ** (p * q)
+
+
+MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
+
+
 def match_duplicate_exponent(node):
     """
     (ab)^p  ->  a^p * b^p
@@ -119,20 +143,49 @@ def match_duplicate_exponent(node):
     return []
 
 
+def duplicate_exponent(root, args):
+    """
+    (ab)^p   ->  a^p * b^p
+    (abc)^p  ->  a^p * b^p * c^p
+    """
+    ab, p = args
+    result = ab[0] ** p
+
+    for b in ab[1:]:
+        result *= b ** p
+
+    return result
+
+
+MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
+
+
 def match_remove_negative_exponent(node):
     """
     a^-p  ->  1 / a^p
     """
     assert node.is_op(OP_POW)
 
-    left, right = node
+    a, p = node
 
-    if right.is_op(OP_NEG):
-        return [P(node, remove_negative_exponent, (left, right[0]))]
+    if p.negated:
+        return [P(node, remove_negative_exponent, (a, p))]
 
     return []
 
 
+def remove_negative_exponent(root, args):
+    """
+    a^-p  ->  1 / a^p
+    """
+    a, p = args
+
+    return L(1) / a ** p.reduce_negation()
+
+
+MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
+
+
 def match_exponent_to_root(node):
     """
     a^(1 / m)  ->  sqrt(a, m)
@@ -148,6 +201,16 @@ def match_exponent_to_root(node):
     return []
 
 
+def exponent_to_root(root, args):
+    """
+    a^(1 / m)  ->  sqrt(a, m)
+    a^(n / m)  ->  sqrt(a^n, m)
+    """
+    a, n, m = args
+
+    return N('sqrt', a if n == 1 else a ** n, m)
+
+
 def match_extend_exponent(node):
     """
     (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1)  # n > 1
@@ -174,66 +237,3 @@ def extend_exponent(root, args):
         return left * left ** L(right.value - 1)
 
     return left * left
-
-
-def subtract_exponents(root, args):
-    """
-    a^p / a^q  ->  a^(p - q)
-    """
-    a, p, q = args
-
-    return a ** (p - q)
-
-
-MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
-
-
-def multiply_exponents(root, args):
-    """
-    (a^p)^q  ->  a^(pq)
-    """
-    a, p, q = args
-
-    return a ** (p * q)
-
-
-MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
-
-
-def duplicate_exponent(root, args):
-    """
-    (ab)^p   ->  a^p * b^p
-    (abc)^p  ->  a^p * b^p * c^p
-    """
-    ab, p = args
-    result = ab[0] ** p
-
-    for b in ab[1:]:
-        result *= b ** p
-
-    return result
-
-
-MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
-
-
-def remove_negative_exponent(root, args):
-    """
-    a^-p  ->  1 / a^p
-    """
-    a, p = args
-
-    return L(1) / a ** p
-
-
-MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
-
-
-def exponent_to_root(root, args):
-    """
-    a^(1 / m)  ->  sqrt(a, m)
-    a^(n / m)  ->  sqrt(a^n, m)
-    """
-    a, n, m = args
-
-    return N('sqrt', a if n == 1 else a ** n, m)