from node import Leaf


# Operator groups ordered from loosest to tightest binding; the index of a
# group is the precedence value returned by _pred().
_OPERATORS = (
    ('+', '-'),
    ('*', '/', 'mod'),
    ('^', ),
)

# Precedence of leaves, unary operators and function applications: higher
# than that of every operator group above.
_MAX_PRED = len(_OPERATORS)

# Operators for which an operand using the same operator never needs
# parentheses, e.g. a + (b + c) == (a + b) + c.
_ASSOCIATIVE = ('+', '*')


def generate_line(root):
    """
    Print an expression tree in a single text line. Where needed, add
    parentheses.

    >>> from node import Node, Leaf
    >>> l0, l1 = Leaf(1), Leaf(2)
    >>> plus = Node('+', l0, l1)
    >>> print(generate_line(plus))
    1 + 2
    >>> plus2 = Node('+', l0, l1)
    >>> times = Node('*', plus, plus2)
    >>> print(generate_line(times))
    (1 + 2) * (1 + 2)
    >>> l2 = Leaf(3)
    >>> uminus = Node('-', l2)
    >>> times = Node('*', plus, uminus)
    >>> print(generate_line(times))
    (1 + 2) * -3
    >>> exp = Leaf('x')
    >>> inf = Leaf('oo')
    >>> minus_inf = Node('-', inf)
    >>> integral = Node('int', exp, minus_inf, inf)
    >>> print(generate_line(integral))
    int(x, -oo, oo)
    >>> print(generate_line(Node('-', plus)))
    -(1 + 2)
    >>> print(generate_line(Node('-', l0, Node('-', l1, l2))))
    1 - (2 - 3)
    >>> print(generate_line(Node('-', Node('-', l0, l1), l2)))
    1 - 2 - 3
    >>> print(generate_line(Node('*', Leaf('a'), Leaf('b'), l1)))
    ab * 2
    """
    return _traverse(root)


def _group_index(op):
    """Return the index of the operator group containing ``op``, or None."""
    for index, group in enumerate(_OPERATORS):
        if op in group:
            return index

    return None


def _is_operator(node):
    """
    Check if a given node is an operator (otherwise, it's a function).
    """
    return any(node.title() in group for group in _OPERATORS)


def _pred(node):
    """
    Get the precedence of a node.

    Operators with two or more operands get the index of their operator
    group; everything else gets _MAX_PRED.
    """
    # Check binary and n-ary operators
    if not isinstance(node, Leaf) and len(node) > 1:
        index = _group_index(node.title())

        if index is not None:
            return index

    # Unary operators, functions and leaves have highest precedence
    return _MAX_PRED


def _traverse(node):
    """
    Render ``node`` and its subtrees as a single line of text.

    The expression tree is traversed using preorder traversal:
    1. Visit the root
    2. Traverse the subtrees in left-to-right order
    """
    if not node:
        return ''

    op = node.title()

    # A node without children is rendered as its bare label.
    if not node.nodes:
        return op

    if not _is_operator(node):
        # Function application: f(arg1, arg2, ...)
        return op + '(' + ', '.join(map(_traverse, node)) + ')'

    if len(node) == 1:
        return _render_unary(op, node[0])

    return _render_nary(node, op)


def _render_unary(op, operand):
    """Render a unary operator applied to ``operand``, e.g. ``-3``."""
    text = _traverse(operand)

    # An operand that binds looser than multiplication (a sum or difference)
    # must be parenthesized: -(1 + 2) is not the same as -1 + 2.
    if _pred(operand) < _group_index('*'):
        text = '(' + text + ')'

    return op + text


def _render_nary(node, op):
    """Render a binary or n-ary infix operator node."""
    node_pred = _pred(node)
    children = list(node)
    texts = []

    for position, child in enumerate(children):
        text = _traverse(child)

        # Check if there is a precedence conflict; if so, add parentheses
        if _needs_parentheses(op, node_pred, child, position):
            text = '(' + text + ')'

        texts.append(text)

    if op != '*':
        return (' ' + op + ' ').join(texts)

    # Multiplication: choose a separator for every adjacent pair of factors
    # so that implicit multiplication (e.g. "2a") is used where possible.
    # Working pairwise also supports products with more than two factors.
    pieces = [texts[0]]

    for left, right, right_text in zip(children, children[1:], texts[1:]):
        pieces.append(_product_separator(left, right, right_text))
        pieces.append(right_text)

    return ''.join(pieces)


def _needs_parentheses(op, node_pred, child, position):
    """
    Check if operand ``child`` at index ``position`` of the infix operator
    ``op`` (with precedence ``node_pred``) must be parenthesized.
    """
    child_pred = _pred(child)

    # A looser-binding operand always needs parentheses; a tighter-binding
    # operand never does.
    if child_pred != node_pred:
        return child_pred < node_pred

    # Same precedence but a different operator, e.g. a + (b - c).
    if child.title() != op:
        return True

    # The same operator nested inside itself.
    if op in _ASSOCIATIVE:
        return False

    # '-', '/' and 'mod' are rendered as left-associative, so only the
    # leftmost operand may repeat the operator without parentheses:
    # a - (b - c) differs from a - b - c. '^' is parenthesized on every side
    # because its associativity depends on the reader's convention.
    return op == '^' or position > 0


def _product_separator(left, right, right_text):
    """
    Return the text to put between two adjacent factors of a product.

    a * b   -> ab
    a * 2   -> a * 2
    a * (b) -> a(b)
    2 * a   -> 2a
    """
    # For a nested product on the left, e.g. (a * b) * c, the factor that is
    # adjacent to ``right`` is the last factor of that product.
    if left.title() == '*':
        left = left[len(left) - 1]

    if is_id(left) and (is_id(right) or right_text.startswith('(')):
        return ''

    if is_int(left) and is_id(right):
        return ''

    return ' * '


def is_id(node):
    """
    Check if ``node`` is a leaf whose label does not consist of digits only,
    e.g. a variable such as ``x`` or a symbol such as ``oo``.
    """
    # NOTE(review): str.isdigit() is False for labels such as '1.5' or '-3',
    # so such leaves (if they can occur) are classified as identifiers.
    return isinstance(node, Leaf) and not node.title().isdigit()


def is_int(node):
    """Check if ``node`` is a leaf whose label consists of digits only."""
    return isinstance(node, Leaf) and node.title().isdigit()