Przeglądaj źródła

Reimplemented line printer using precedences in accordance with the parser.

- All special cases for negation have been removed, except when they are also
  present in the parser.
- Some callbacks have been added to be able to customize output for certain
  operators in an extension of the Node class. This prevents chaotic scenarios
  in the graph_drawing library.
- Some useful unit tests have been added to test the precedence and
  callback mechanisms.
Taddeus Kroes 13 lat temu
rodzic
commit
e17779f211
4 zmienionych plików z 366 dodań i 125 usunięć
  1. 225 114
      line.py
  2. 37 2
      node.py
  3. 101 9
      tests/test_line.py
  4. 3 0
      tests/test_node.py

+ 225 - 114
line.py

@@ -1,45 +1,94 @@
 from traverse import traverse_depth_first
+from node import Node
 
 
+all_parens = ('()', '[]', '||', '{}')
 OPERATORS = (
-    ('vv', ),
-    ('^^', ),
-    ('=', ),
-    ('+', '-'),
-    ('*', 'mod'),
-    ('/', ),
-    ('^', '_'),
+    ('left', ('vv', )),
+    ('left', ('^^', )),
+    ('left', ('=', )),
+    ('left', ('+', '-')),
+    ('nonassoc', ('int', 'd/d')),
+    ('left', ('*', 'mod')),
+    ('left', ('/', )),
+    ('nonassoc', ('\'', )),
+    ('nonassoc', ('neg', )),
+    ('nonassoc', ('function', )),
+    ('right', ('^', '_')),
+    ('nonassoc', all_parens),
 )
 
-NEG_PRED = 3
+assocs = {}
+preds = {}
+
+for i, (assoc, ops) in enumerate(OPERATORS):
+    for op in ops:
+        assocs[op] = assoc
+        preds[op] = i
+
+NEG_PRED = preds['neg']
+FUNC_PRED = preds['function']
 MAX_PRED = len(OPERATORS)
 
 
-def is_operator(node):
+def is_function(node):
     """
-    Check if a given node is an operator (otherwise, it's a function).
+    Check if a given node is a function. A node is considered a function if it
+    is not a leaf, and is not known in the precedences list above.
     """
-    label = node.title()
-
-    return any(map(lambda x: label in x, OPERATORS))
+    return not node.is_leaf and node.title() not in preds
 
 
 def pred(node):
     """
-    Get the precedence of an operator node.
+    Get the operator precedence of a node. Leaf nodes have the highest
+    precedence for practical reasons.
     """
-    # Check binary and n-ary operators
-    if not node.is_leaf and len(node) > 1:
+    # Check known operators
+    if not node.is_leaf:
         op = node.title()
 
-        for i, group in enumerate(OPERATORS):
-            if op in group:
-                return i
+        if node.is_negation():
+            if node[0].title() in '*/':
+                return preds['-']
+
+            return NEG_PRED
+
+        #if node.is_postfix() and not node[0].is_leaf \
+        #        and node[0].title() in all_parens:
+        #    return preds['()']
+
+        if op in preds:
+            return preds[op]
+
+        return FUNC_PRED
 
     # Unary operator and leaves have highest precedence
     return MAX_PRED
 
 
+def rightmost_node(node):
+    if node.is_leaf or not len(node):
+        return node
+
+    return rightmost_node(node[-1])
+
+
+def is_unary_prefix(node):
+    """
+    Check if a node is a unary operator that is placed before the operand.
+    """
+    return not node.is_leaf and len(node) == 1 and node.title() == '-'
+
+
+def is_left_assoc(op):
+    return op in assocs and assocs[op] == 'left'
+
+
+def is_right_assoc(op):
+    return op in assocs and assocs[op] == 'right'
+
+
 def is_id(node):
     return node.is_leaf and not node.title().isdigit()
 
@@ -52,6 +101,24 @@ def is_power(node):
     return not node.is_leaf and node.title() == '^'
 
 
+def preprocess_node(node):
+    node = node.clone()
+    node.preprocess_str_exp()
+
+    if node.negated:
+        node.negated -= 1
+        return Node('-', preprocess_node(node))
+
+    if not node.is_leaf:
+        for i, child in enumerate(node):
+            node[i] = preprocess_node(child)
+
+        if node.title() == '+' and node[1].is_negation():
+            return Node('-', node[0], node[1][0])
+
+    return node
+
+
 def generate_line(root):
     """
     Print an expression tree in a single text line. Where needed, add
@@ -77,13 +144,6 @@ def generate_line(root):
     >>> print generate_line(times)
     (1 + 2) * -3
 
-    >>> exp = Leaf('x')
-    >>> inf = Leaf('oo')
-    >>> minus_inf = Node('-', inf)
-    >>> integral = Node('int', exp, minus_inf, inf)
-    >>> print generate_line(integral)
-    int(x, -oo, oo)
-
     >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
     >>> print generate_line(minus)
     2 - -15x
@@ -107,124 +167,175 @@ def generate_line(root):
 
     content = {}
 
+    def mult_sign(left, right, lparens, rparens):
+        # Get the previous multiplication element in an nary multiplication
+        if left.title() == '*':
+            left = rightmost_node(left)
+
+        # a * b -> ab
+        # a * 2 -> a * 2
+        # a * (b) -> a(b)
+        # (a) * b -> (a)b
+        # (a) * (b) -> (a)(b)
+        # 2 * a -> 2a
+        # a * sin(b) -> a sin(b)
+        left_char = content[left][-1]
+        right_char = content[right][0]
+        left_paren = lparens or left_char in ')]}'
+        right_paren = rparens or right_char in '([{'
+        right_alpha = right_char.isalpha()
+        left_simple = is_id(left) or is_int(left)
+
+        if left_paren or (right_paren and left_simple) \
+                or (is_id(left) and is_id(right)) \
+                or (is_int(left) and right_alpha):
+            return ''
+
+        if is_id(left) and right_alpha:
+            return ' '
+
+        return ' * '
+
     def construct_unary(node):
         op = node.title()
         value = node[0]
+        strval = content[value]
 
-        # -a
-        # -3 * 4
-        # --a
-        if value.is_leaf \
-                or not ' ' in content[value] or pred(value) > NEG_PRED:
-            return op + content[value]
+        if op in ('()', '[]', '||', '{}'):
+            return op[0] + strval + op[1]
 
-        # -(a + b)
-        return '%s(%s)' % (op, content[value])
+        parens = False
 
+        if pred(value) < pred(node):
+            parens = not value.is_negation()
+        elif pred(value) == pred(node):
+            parens = len(value) > 1
 
-    def construct_nary(node):
+        if parens:
+            strval = '(' + strval + ')'
+
+        if node.is_postfix():
+            return strval + node.operator()
+
+        prefix = node.operator()
+
+        if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
+            prefix += ' '
+
+        return prefix + strval
+
+    def construct_binary(node):
         op = node.title()
+        op_pred = pred(node)
 
-        # N-ary operator
-        node_pred = pred(node)
-        sep = ' ' + op + ' '
-        e = []
+        if node.no_spacing:
+            sep = node.operator()
+        else:
+            sep = ' ' + node.operator() + ' '
 
-        for i, child in enumerate(node):
-            exp = content[child]
+        left, right = node
+        lstr = content[left]
+        rstr = content[right]
+        lpred = pred(left)
+        rpred = pred(right)
+        lparens = rparens = False
+        unary_right = is_unary_prefix(right)
 
-            #if i and op == '+' and exp[:2] == '-(':
-            #    exp = '-' + exp[2:-1]
-            #    print 'exp:', exp
-
-            # Check if there is a precedence conflict
-            # If so, add parentheses
-            child_pred = pred(child)
-
-            if child.negated:
-                # (-a) ^ b
-                # -a ^ -b
-                # (-a) * b
-                # a * -b
-                # (-a) / b
-                if node_pred > NEG_PRED:
-                    exp = '(' + exp + ')'
-            elif child_pred < node_pred:
-                exp = '(' + exp + ')'
-            elif child_pred == node_pred:
-                if i and (op != child.title() or op == '/' \
-                          or (op == '+' and child[1].negated)):
-                    exp = '(' + exp + ')'
-                elif not i and op == '^':
-                    exp = '(' + exp + ')'
+        if lpred < op_pred or (op in '*/' and left.is_negation()):
+            lparens = True
+        elif lpred == op_pred:
+            lparens = is_right_assoc(left.title()) or is_right_assoc(op)
 
-            e.append(exp)
+        if rpred < op_pred:
+            rparens = not unary_right
+        elif rpred == op_pred and len(right) > 1:
+            if right.title() == op:
+                rparens = not is_right_assoc(op)
+            elif is_left_assoc(right.title()):
+                rparens = True
 
-        if op == '*':
-            # Check if an explicit multiplication sign is nessecary
-            left, right = node
+        # Check if multiplication sign is necessary
+        if op == '*' and not unary_right:
+            sep = mult_sign(left, right, lparens, rparens)
 
-            # Get the previous multiplication element if the arity is
-            # greater than 2
-            if left.title() == '*':
-                left = left[1]
+        if lparens:
+            lstr = '(' + lstr + ')'
 
-            # a * b -> ab
-            # a * 2 -> a * 2
-            # a * (b) -> a(b)
-            # (a) * b -> (a)b
-            # (a) * (b) -> (a)(b)
-            # 2 * a -> 2a
-            l = e[0][-1]
-            r = e[1][0]
-            left_simple = is_id(left) or is_int(left)
+        if rparens:
+            rstr = '(' + rstr + ')'
 
-            if (r in ('(', '[') and left_simple) or (l == ')' and r != '-') \
-                    or (left_simple and r.isalpha()):
-                sep = ''
+        return lstr + sep + rstr
 
-        exp = sep.join(e)
+    def construct_nary_mult(node):
+        op_pred = pred(node)
+        lstr = content[node[0]]
+        lparens = pred(node[0]) < op_pred or node[0].is_negation()
 
-        if node.negated and op not in ('*', '/', '^'):
-            exp = '(' + exp + ')'
+        if lparens:
+            lstr = '(' + lstr + ')'
 
-        return exp
+        for i, right in enumerate(node[1:]):
+            rparens = pred(right) < op_pred
+            rstr = content[right]
 
-    def construct_function(node):
-        buf = []
+            if rparens:
+                rstr = '(' + rstr + ')'
+
+            sign = mult_sign(node[i], right, lparens, rparens)
+            lstr += sign + rstr
+            lparens = rparens
+
+        return lstr
+
+    def construct_nary(node):
+        if node.title() == '*':
+            return construct_nary_mult(node)
+
+        op_pred = pred(node)
+        e = []
 
         for child in node:
-            buf.append(content[child])
+            exp = content[child]
+
+            if pred(child) < op_pred:
+                exp = '(' + exp + ')'
 
-        return '%s(%s)' % (node.title(), ', '.join(buf))
+            e.append(exp)
+
+        return (' ' + node.operator() + ' ').join(e)
+
+    def construct_function(node):
+        children = [content[child] for child in node]
+        return '%s(%s)' % (node.operator(), ', '.join(children))
+
+    # Convert negations to unary nodes to be able to account for operator
+    # precedence
+    root = preprocess_node(root.clone())
 
     # Traverse the expression tree and construct the mathematical expression in
     # the leafs and nodes in depth first order.
     for node in traverse_depth_first(root):
+        custom = node.custom_line()
+
+        if custom is not None:
+            content[node] = custom
+            continue
+
         if node.is_leaf:
-            content[node] = str(node)
+            nodestr = str(node.value)
         else:
-            arity = len(node)
-
-            if is_operator(node):
-                if arity == 1:
-                    content[node] = construct_unary(node)
-                else:
-                    content[node] = construct_nary(node)
+            arity = node.arity()
+
+            if arity == 1:
+                nodestr = construct_unary(node)
+            elif is_function(node):
+                nodestr = construct_function(node)
+            elif arity == 2:
+                nodestr = construct_binary(node)
             else:
-                result = None
-
-                if hasattr(node, 'construct_function'):
-                    children = [content[c] for c in node]
-                    result = node.construct_function(children)
-
-                if result == None:
-                    result = construct_function(node)
-
-                content[node] = result
+                nodestr = construct_nary(node)
 
-            # Add negations
-            content[node] = '-' * node.negated + content[node]
+        content[node] = node.postprocess_str(nodestr)
 
-    # Merge binary plus and unary minus signs into binary minus.
+    # Merge binary plus and unary minus signs into a binary minus
     return content[root].replace('+ -', '- ')

+ 37 - 2
node.py

@@ -7,6 +7,7 @@ class Node(object):
         super(Node, self).__init__()
         self.value, self.nodes = value, list(nodes)
         self.is_leaf = False
+        self.no_spacing = kwargs.get('no_spacing', False)
         self.negated = kwargs.get('negated', 0)
 
     def __getitem__(self, n):
@@ -26,11 +27,17 @@ class Node(object):
                and self.nodes == node.nodes
 
     def __neg__(self):
-        copied = deepcopy(self)
+        copied = self.clone()
         copied.negated += 1
 
         return copied
 
+    def __pos__(self):
+        copied = self.clone()
+        copied.negated = max(copied.negated - 1, 0)
+
+        return copied
+
     def __str__(self):
         return '<Node value=%s nodes=%s negated=%d>' \
                % (str(self.value), str(self.nodes), self.negated)
@@ -41,11 +48,39 @@ class Node(object):
     def title(self):
         return str(self.value)
 
+    def operator(self):
+        return self.value
+
+    def clone(self):
+        return deepcopy(self)
+
+    def arity(self):
+        return len(self)
+
+    def is_postfix(self):
+        return self.value == '\''
+
+    def is_negation(self):
+        return self.value == '-' and len(self) == 1
+
+    def custom_line(self):
+        pass
+
+    def preprocess_str_exp(self):
+        pass
+
+    def postprocess_str(self, string):
+        return string
+
 
 class Leaf(Node):
     def __init__(self, value, **kwargs):
         super(Leaf, self).__init__(value, **kwargs)
-        self.value = value
+
+        if type(value) in (int, float) and value < 0:
+            self.value = abs(value)
+            self.negated += 1
+
         self.nodes = None
         self.is_leaf = True
 

+ 101 - 9
tests/test_line.py

@@ -1,5 +1,6 @@
 import unittest
 import doctest
+import new
 
 import line
 from node import Node as N, Leaf as L
@@ -38,6 +39,8 @@ class TestLine(unittest.TestCase):
         minus = N('-', l0, plus)
         self.assertEquals(generate_line(minus), '1 - (2 + 3)')
 
+        power = N('^', l0, N('_', l1, l2))
+        self.assertEquals(generate_line(power), '1 ^ 2 _ 3')
         power = N('^', l0, N('^', l1, l2))
         self.assertEquals(generate_line(power), '1 ^ 2 ^ 3')
         power = N('^', N('^', l0, l1), l2)
@@ -63,11 +66,8 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(plus), '1 + 2 + 3')
 
     def test_function(self):
-        exp = L('x')
-        inf = L('oo')
-        minus_inf = -L('oo')
-        integral = N('int', exp, minus_inf, inf)
-        self.assertEquals(generate_line(integral), 'int(x, -oo, oo)')
+        sin = N('sin', N('*', L(2), L('x')))
+        self.assertEquals(generate_line(sin), 'sin(2x)')
 
     def test_mod(self):
         l0, l1 = L(1), L(2)
@@ -77,7 +77,7 @@ class TestLine(unittest.TestCase):
     def test_multiplication_identifiers(self):
         a, b = L('a'), L('b')
         self.assertEquals(generate_line(N('*', a, b)), 'ab')
-        self.assertEquals(generate_line(N('*', a, -b)), 'a(-b)')
+        self.assertEquals(generate_line(N('*', a, -b)), 'a * -b')
 
     def test_multiplication_constant_identifier(self):
         l0, a = L(2), L('a')
@@ -206,8 +206,14 @@ class TestLine(unittest.TestCase):
         neg = -N('-', L(1), L(2))
         self.assertEquals(generate_line(neg), '-(1 - 2)')
 
+        # FIXME: neg = N('+', L(1), N('+', L(1), L(2)))
+        # FIXME: self.assertEquals(generate_line(neg), '1 + 1 + 2')
+
+        neg = N('+', N('+', L(1), L(2)), L(3))
+        self.assertEquals(generate_line(neg), '1 + 2 + 3')
+
         neg = N('+', L(1), N('+', L(1), L(2)))
-        self.assertEquals(generate_line(neg), '1 + 1 + 2')
+        self.assertEquals(generate_line(neg), '1 + (1 + 2)')
 
         neg = N('+', L(1), -N('+', L(1), L(2)))
         self.assertEquals(generate_line(neg), '1 - (1 + 2)')
@@ -219,7 +225,7 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(neg), '-4a')
 
         neg = N('*', L(4), -L('a'))
-        self.assertEquals(generate_line(neg), '4(-a)')
+        self.assertEquals(generate_line(neg), '4 * -a')
 
         neg = -N('*', L(4), L(5))
         self.assertEquals(generate_line(neg), '-4 * 5')
@@ -234,11 +240,15 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(plus), 'a / b - c / d')
 
         mul = N('*', N('+', L('a'), L('b')), -N('+', L('c'), L('d')))
-        self.assertEquals(generate_line(mul), '(a + b)(-(c + d))')
+        self.assertEquals(generate_line(mul), '(a + b) * -(c + d)')
 
     def test_double_negation(self):
         neg = --L(1)
         self.assertEquals(generate_line(neg), '--1')
+        neg = --N('*', L('x'), L(2))
+        self.assertEquals(generate_line(neg), '--x * 2')
+        neg = --N('^', L('x'), L(2))
+        self.assertEquals(generate_line(neg), '--x ^ 2')
 
     def test_divide_fractions(self):
         a, b, c, d = L('a'), L('b'), L('c'), L('d')
@@ -246,3 +256,85 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(div), 'a / (b / c)')
         div = N('/', N('/', a, b), N('/', c, d))
         self.assertEquals(generate_line(div), 'a / b / (c / d)')
+
+    def test_prime(self):
+        a, b, c, d = L('a'), L('b'), L('c'), L('d')
+        root = N('*', a, N("'", b))
+        self.assertEquals(generate_line(root), "a b'")
+        root = N("'", -a)
+        self.assertEquals(generate_line(root), "-a'")
+        root = -N("'", a)
+        self.assertEquals(generate_line(root), "-(a')")
+        root = N("'", N('*', a, b))
+        self.assertEquals(generate_line(root), "(ab)'")
+        root = N("'", N('/', a, b))
+
+    def test_function(self):
+        root = N('sin', L('x'))
+        self.assertEquals(generate_line(root), 'sin x')
+        root = N('sin', N('+', L('x'), L(2)))
+        self.assertEquals(generate_line(root), 'sin(x + 2)')
+        root = N('dummyfunc', L('x'), L(2))
+        self.assertEquals(generate_line(root), 'dummyfunc(x, 2)')
+
+    def test_no_spacing(self):
+        root = N('+', L('x'), L(2), no_spacing=True)
+        self.assertEquals(generate_line(root), 'x+2')
+
+    def test_explicit_parentheses(self):
+        root = N('[]', L('x'))
+        self.assertEquals(generate_line(root), '[x]')
+        root = N('()', L('x'))
+        self.assertEquals(generate_line(root), '(x)')
+        root = N('{}', L('x'))
+        self.assertEquals(generate_line(root), '{x}')
+
+        root = N('^', N('[]', N('+', L('x'), L('y'))), L(2))
+        self.assertEquals(generate_line(root), '[x + y] ^ 2')
+
+    def test_abs(self):
+        root = N('||', L('x'))
+        self.assertEquals(generate_line(root), '|x|')
+        root = N('||', N('+', L('x'), L(1)))
+        self.assertEquals(generate_line(root), '|x + 1|')
+        root = N('ln', N('||', L('x')))
+        self.assertEquals(generate_line(root), 'ln|x|')
+
+    def test_postprocess_str(self):
+        root = N('int', N('^', L('x'), L(2)), L('x'))
+        root.arity = lambda: 1
+        root.postprocess_str = lambda s: s + ' dx'
+        self.assertEquals(generate_line(root), 'int x ^ 2 dx')
+
+    def test_concat_with_negation(self):
+        root = N('*', -L(2), L('x'))
+        self.assertEquals(generate_line(root), '(-2)x')
+        root = N('*', N('*', L(3), -L(2)), L('x'))
+        self.assertEquals(generate_line(root), '3 * -2x')
+        root = N('*', L(3), -L(2), L('x'))
+        self.assertEquals(generate_line(root), '3 * -2 * x')
+
+    def test_first_child_negation(self):
+        root = N('*', -L(1), L(2))
+        self.assertEquals(generate_line(root), '(-1)2')
+        root = -N('*', L(1), L(2))
+        self.assertEquals(generate_line(root), '-1 * 2')
+        root = N('/', -L(1), L(2))
+        self.assertEquals(generate_line(root), '(-1) / 2')
+        root = -N('/', L(1), L(2))
+        self.assertEquals(generate_line(root), '-1 / 2')
+
+    def test_postfix_brackets(self):
+        root = N('*', L('x'), N("'", N('[]', N('^', L('x'), L(2)))))
+        self.assertEquals(generate_line(root), "x[x ^ 2]'")
+
+    def test_custom_line(self):
+        root = N('*', L(1), L(2))
+        root.custom_line = lambda: 'test'
+        self.assertEquals(generate_line(root), 'test')
+
+    def test_preprocess_str_exp(self):
+        root = N('-', L(1))
+        def addbrackets(self): self[0] = N('[]', self[0])
+        root.preprocess_str_exp = new.instancemethod(addbrackets, root)
+        self.assertEquals(generate_line(root), '-[1]')

+ 3 - 0
tests/test_node.py

@@ -51,3 +51,6 @@ class TestNode(unittest.TestCase):
         self.assertEqual(Node('+', l1, l2, negated=1).negated, 1)
 
         self.assertEqual(Leaf(1, negated=2).negated, 2)
+
+    def test_negated_int_constructor(self):
+        self.assertEquals(-Leaf(2), Leaf(-2))