Browse Source

Made generate_line use the iterative, depth-first tree walker.

Sander Mathijs van Veen 14 năm trước cách đây
mục cha
commit
11940973bd
3 tập tin đã thay đổi với 191 bổ sung45 xóa
  1. 186 43
      line.py
  2. 1 1
      tests/test_line.py
  3. 4 1
      traverse.py

+ 186 - 43
line.py

@@ -1,7 +1,41 @@
 from node import Leaf
+from traverse import traverse_depth_first
 
+OPERATORS = [
+        ('+', '-'),
+        ('*', '/', 'mod'),
+        ('^', )
+        ]
 
-def generate_line(root):
+MAX_PRED = len(OPERATORS)
+
+
+def is_operator(node):
+    """
+    Check if a given node is an operator (otherwise, it's a function).
+    """
+    label = node.title()
+
+    return any(map(lambda x: label in x, OPERATORS))
+
+
+def pred(node):
+    """
+    Get the precedence of an operator node.
+    """
+    # Check binary and n-ary operators
+    if not node.is_leaf and len(node) > 1:
+        op = node.title()
+
+        for i, group in enumerate(OPERATORS):
+            if op in group:
+                return i
+
+    # Unary operator and leaves have highest precedence
+    return MAX_PRED
+
+
+def generate_line_old(root):
     """
     Print an expression tree in a single text line. Where needed, add
     parentheses.
@@ -9,60 +43,28 @@ def generate_line(root):
     >>> from node import Node, Leaf
     >>> l0, l1 = Leaf(1), Leaf(2)
     >>> plus = Node('+', l0, l1)
-    >>> print generate_line(plus)
+    >>> print generate_line_old(plus)
     1 + 2
 
     >>> plus2 = Node('+', l0, l1)
     >>> times = Node('*', plus, plus2)
-    >>> print generate_line(times)
+    >>> print generate_line_old(times)
     (1 + 2)(1 + 2)
 
     >>> l2 = Leaf(3)
     >>> uminus = Node('-', l2)
     >>> times = Node('*', plus, uminus)
-    >>> print generate_line(times)
+    >>> print generate_line_old(times)
     (1 + 2) * -3
 
     >>> exp = Leaf('x')
     >>> inf = Leaf('oo')
     >>> minus_inf = Node('-', inf)
     >>> integral = Node('int', exp, minus_inf, inf)
-    >>> print generate_line(integral)
+    >>> print generate_line_old(integral)
     int(x, -oo, oo)
     """
 
-    operators = [
-            ('+', '-'),
-            ('*', '/', 'mod'),
-            ('^', )
-            ]
-
-    max_pred = len(operators)
-
-    def is_operator(node):
-        """
-        Check if a given node is an operator (otherwise, it's a function).
-        """
-        label = node.title()
-        either = lambda a, b: a or b
-
-        return reduce(either, map(lambda x: label in x, operators))
-
-    def pred(node):
-        """
-        Get the precedence of an operator node.
-        """
-        # Check binary and n-ary operators
-        if not isinstance(node, Leaf) and len(node) > 1:
-            op = node.title()
-
-            for i, group in enumerate(operators):
-                if op in group:
-                    return i
-
-        # Unary operator and leaves have highest precedence
-        return max_pred
-
     def traverse(node):
         """
         The expression tree is traversed using preorder traversal:
@@ -94,9 +96,9 @@ def generate_line(root):
                 # -(4 * 5)
                 # 1 - 4 * 5
                 # 1 + -(4 * 5)  ->  1 - 4 * 5
-                if ' ' in sub_exp and not (not isinstance(sub, Leaf) \
-                        and hasattr(node, 'marked_negation')
-                        and pred(sub) > 0):
+                if ' ' in sub_exp and sub.is_Leaf \
+                        and hasattr(node, 'marked_negation') \
+                        and pred(sub) > 0:
                     sub_exp = '(' + sub_exp + ')'
 
                 result = op + sub_exp
@@ -166,7 +168,7 @@ def generate_line(root):
 
 
 def is_negation(node):
-    if isinstance(node, Leaf):
+    if node.is_leaf:
         return False
 
     return node.title() == '-' and len(node) == 1
@@ -176,11 +178,152 @@ def is_id(node):
     if is_negation(node):
         return is_id(node[0])
 
-    return isinstance(node, Leaf) and not node.title().isdigit()
+    return node.is_leaf and not node.title().isdigit()
 
 
 def is_int(node):
     if is_negation(node):
         return is_int(node[0])
 
-    return isinstance(node, Leaf) and node.title().isdigit()
+    return node.is_leaf and node.title().isdigit()
+
+
+def generate_line(root):
+    """
+    Print an expression tree in a single text line. Where needed, add
+    parentheses.
+
+    >>> from node import Node, Leaf
+    >>> l0, l1 = Leaf(1), Leaf(2)
+    >>> plus = Node('+', l0, l1)
+    >>> print generate_line(plus)
+    1 + 2
+
+    >>> plus2 = Node('+', l0, l1)
+    >>> times = Node('*', plus, plus2)
+    >>> print generate_line(times)
+    (1 + 2)(1 + 2)
+
+    >>> l2 = Leaf(3)
+    >>> uminus = Node('-', l2)
+    >>> times = Node('*', plus, uminus)
+    >>> print generate_line(times)
+    (1 + 2) * -3
+
+    >>> exp = Leaf('x')
+    >>> inf = Leaf('oo')
+    >>> minus_inf = Node('-', inf)
+    >>> integral = Node('int', exp, minus_inf, inf)
+    >>> print generate_line(integral)
+    int(x, -oo, oo)
+
+    >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
+    >>> print generate_line(minus)
+    2 - -15x
+
+    >>> left = Node('/', Leaf(22), Leaf(77))
+    >>> right = Node('/', Leaf(28), Leaf(77))
+    >>> minus = Node('-', left, right)
+    >>> print generate_line(minus)
+    22 / 77 - 28 / 77
+
+    >>> plus = Node('+', left, Node('-', right))
+    >>> print generate_line(plus)
+    22 / 77 - 28 / 77
+    """
+
+    if not root:
+        return '<empty expression>'
+
+    if root.is_leaf:
+        return root.title()
+
+    content = {}
+
+    def construct_unary(node):
+        result = node.title()
+        value = node[0]
+
+        # -a
+        # -3*4
+        # --a
+        if value.is_leaf \
+                or not ' ' in content[value] or pred(value) > 0:
+            return result + content[value]
+
+        return '%s(%s)' % (result, content[value])
+
+
+    def construct_nary(node):
+        op = node.title()
+
+        # N-ary operator
+        node_pred = pred(node)
+        sep = ' ' + op + ' '
+        e = []
+
+        for i, child in enumerate(node):
+            exp = content[child]
+
+            # Check if there is a precedence conflict
+            # If so, add parentheses
+            child_pred = pred(child)
+
+            if child_pred < node_pred or \
+                    (i and child_pred == node_pred \
+                        and op != child.title()):
+                exp = '(' + exp + ')'
+
+            e.append(exp)
+
+        if op == '*':
+            # Check if an explicit multiplication sign is nessecary
+            left, right = node
+
+            # Get the previous multiplication element if the arity is
+            # greater than 2
+            if left.title() == '*':
+                left = left[1]
+
+            # a * b -> ab
+            # a * 2 -> a * 2
+            # a * (b) -> a(b)
+            # (a) * b -> (a)b
+            # (a) * (b) -> (a)(b)
+            # 2 * a -> 2a
+            left_paren = e[0][-1] == ')'
+            right_paren = e[1][0] == '('
+
+            if (is_id(left) or left_paren or is_int(left)) \
+                    and (is_id(right) or right_paren):
+                sep = ''
+
+        return sep.join(e)
+
+    def construct_function(node):
+        buf = []
+
+        for child in node:
+            buf.append(content[child])
+
+        return '%s(%s)' % (node.title(), ', '.join(buf))
+
+    # Traverse the expression tree and construct the mathematical expression in
+    # the leafs and nodes in depth first order.
+    for node in traverse_depth_first(root):
+        if node.is_leaf:
+            content[node] = node.title()
+            continue
+
+        arity = len(node)
+
+        if is_operator(node):
+            if arity == 1:
+                content[node] = construct_unary(node)
+            else:
+                content[node] = construct_nary(node)
+        else:
+            content[node] = construct_function(node)
+
+    # Merge binary plus and unary minus signs into binary minus.
+    return content[root].replace('+ -', '- ')

+ 1 - 1
tests/test_line.py

@@ -168,7 +168,7 @@ class TestLine(unittest.TestCase):
         self.assertEquals(generate_line(neg), '-4a')
 
         neg = N('-', N('*', L(4), L(5)))
-        self.assertEquals(generate_line(neg), '-(4 * 5)')
+        self.assertEquals(generate_line(neg), '-4 * 5')
 
         plus = N('+', L(1), N('-', N('*', L(4), L(5))))
         self.assertEquals(generate_line(plus), '1 - 4 * 5')

+ 4 - 1
traverse.py

@@ -21,6 +21,9 @@ def traverse_depth_first(root):
     the tree.
 
     >>> from node import Node as N, Leaf as L
+    >>> root = N('*', L(4), L('a'))
+    >>> print map(lambda n: n.title(), traverse_depth_first(root))
+    ['4', 'a', '*']
     >>> root = N('+', N('/', L(1), L(2)), N('*', L(3), L(4)))
     >>> print map(lambda n: n.title(), traverse_depth_first(root))
     ['1', '2', '/', '3', '4', '*', '+']
@@ -42,7 +45,7 @@ def traverse_depth_first(root):
     node = root
 
     while True:
-        while not isinstance(node, Leaf):
+        while not node.is_leaf:
             # Traverse left and save the path
             stack.append(node)
             path.append(0)