| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- from node import Leaf
def generate_line(root):
    """
    Render an expression tree as a single text line and return it. Where
    needed, add parentheses.

    Operators with two or more operands are joined by their symbol
    surrounded by spaces; a multiplication may instead be written by
    juxtaposition (``2x``, ``ab``, ``a(b + c)``). Unary operators are
    prefixed to their operand, and any other node with children is
    rendered as a function call.

    >>> from node import Node, Leaf
    >>> l0, l1 = Leaf(1), Leaf(2)
    >>> plus = Node('+', l0, l1)
    >>> print(generate_line(plus))
    1 + 2
    >>> plus2 = Node('+', l0, l1)
    >>> times = Node('*', plus, plus2)
    >>> print(generate_line(times))
    (1 + 2) * (1 + 2)
    >>> l2 = Leaf(3)
    >>> uminus = Node('-', l2)
    >>> times = Node('*', plus, uminus)
    >>> print(generate_line(times))
    (1 + 2) * -3
    >>> print(generate_line(Node('-', plus)))
    -(1 + 2)
    >>> print(generate_line(Node('-', l2, Node('-', l0, l1))))
    3 - (1 - 2)
    >>> exp = Leaf('x')
    >>> inf = Leaf('oo')
    >>> minus_inf = Node('-', inf)
    >>> integral = Node('int', exp, minus_inf, inf)
    >>> print(generate_line(integral))
    int(x, -oo, oo)
    """
    # Operator groups ordered by increasing precedence; the index of a group
    # is the precedence value returned by pred().
    operators = [
        ('+', '-'),
        ('*', '/', 'mod'),
        ('^', ),
    ]
    # Operators for which "a op (b op c)" equals "a op b op c", so a
    # right-hand operand using the same operator needs no parentheses.
    associative = ('+', '*')
    # Precedence of leaves, unary operators and functions: higher than every
    # operator group, so such operands are never wrapped by an operator.
    max_pred = len(operators)

    def is_operator(node):
        """
        Check if a given node is an operator (otherwise, it's a function).
        """
        label = node.title()
        return any(label in group for group in operators)

    def pred(node):
        """
        Get the precedence of a node: the index of its operator group for
        operators with two or more operands, ``max_pred`` otherwise.
        """
        # Check binary and n-ary operators
        if not isinstance(node, Leaf) and len(node) > 1:
            op = node.title()
            for i, group in enumerate(operators):
                if op in group:
                    return i
        # Unary operators, functions and leaves have highest precedence
        return max_pred

    def mul_sep(left, right, right_exp):
        """
        Return the text placed between two adjacent factors ``left`` and
        ``right`` of a product: '' for implicit multiplication, ' * '
        otherwise. ``right_exp`` is the already rendered (and possibly
        parenthesized) ``right`` factor.
        """
        # If the left factor is itself a product, the factor that ends up
        # next to ``right`` is its last one.
        if left.title() == '*':
            left = left[len(left) - 1]
        # a * b -> ab
        # a * 2 -> a * 2
        # a * (b) -> a(b)
        # 2 * a -> 2a
        if (is_id(left) and (is_id(right) or right_exp.startswith('('))) \
                or (is_int(left) and is_id(right)):
            return ''
        return ' * '

    def traverse(node):
        """
        Render ``node`` and its subtrees. The expression tree is traversed
        using preorder traversal:
        1. Visit the root
        2. Traverse the subtrees in left-to-right order
        """
        if not node:
            return '<empty expression>'
        op = node.title()
        if not node.nodes:
            return op
        if not is_operator(node):
            # Function: name(arg1, arg2, ...)
            return op + '(' + ', '.join(map(traverse, node)) + ')'
        if len(node) == 1:
            # Unary operator. An operand that is itself an operator with
            # two or more operands binds looser than the prefix, so it
            # must be wrapped: -(1 + 2), not -1 + 2.
            operand = node[0]
            exp = traverse(operand)
            if pred(operand) < max_pred:
                exp = '(' + exp + ')'
            return op + exp
        # N-ary operator
        node_pred = pred(node)
        children = list(node)
        parts = []
        for i, child in enumerate(children):
            exp = traverse(child)
            child_pred = pred(child)
            # A lower-precedence operand always needs parentheses. An
            # equal-precedence operand to the right needs them too, unless
            # it uses the same associative operator: a - (b - c) and
            # a * (b / c) differ from a - b - c and a * b / c, whereas
            # a + (b + c) reads the same as a + b + c.
            if child_pred < node_pred or (
                    i > 0 and child_pred == node_pred
                    and (child.title() != op or op not in associative)):
                exp = '(' + exp + ')'
            parts.append(exp)
        if op != '*':
            return (' ' + op + ' ').join(parts)
        # Multiplication: pick the separator for each adjacent pair of
        # factors, so products with more than two factors work as well.
        result = parts[0]
        for i in range(1, len(children)):
            result += mul_sep(children[i - 1], children[i], parts[i])
            result += parts[i]
        return result

    return traverse(root)
def is_id(node):
    """
    Return True if ``node`` is a Leaf holding an identifier (e.g. ``x`` or
    ``oo``) rather than a numeric literal.

    A label counts as numeric when it consists only of digits, or when it
    contains a digit and ``float()`` accepts it (``2.5``, ``-3``, ``1e5``).
    Without the second test such labels would count as identifiers. The
    implicit-multiplication rule in generate_line could then glue numbers
    together, e.g. ``2 * 2.5`` -> ``22.5`` or ``2 * -3`` -> ``2-3``.
    """
    if not isinstance(node, Leaf):
        return False
    label = node.title()
    if label.isdigit():
        return False
    if not any(ch.isdigit() for ch in label):
        # Digit-free labels such as 'x', 'oo' or 'inf' stay identifiers,
        # even when float() happens to accept them.
        return True
    try:
        float(label)
    except ValueError:
        # Mixed labels such as 'x1' are identifiers.
        return True
    return False
def is_int(node):
    """
    Return True if ``node`` is a Leaf whose label consists solely of digit
    characters (e.g. ``42``). Signed or decimal labels do not qualify.
    """
    if not isinstance(node, Leaf):
        return False
    return node.title().isdigit()
|