from traverse import traverse_depth_first
from node import Node


# Operator titles that denote a pair of enclosing brackets.
all_parens = ('()', '[]', '||', '{}')

# Operator table, ordered from weakest to strongest binding. The index of an
# entry is used as the precedence level of all operators it lists; an operand
# with a lower level than its parent operator is wrapped in parentheses.
OPERATORS = (
    ('left', ('vv', )),
    ('left', ('^^', )),
    ('left', ('=', )),
    ('left', ('+', '-')),
    ('nonassoc', ('int', 'd/d')),
    ('left', ('*', 'mod')),
    ('left', ('/', )),
    ('nonassoc', ('\'', )),
    ('nonassoc', ('neg', )),
    ('nonassoc', ('function', )),
    ('right', ('^', '_')),
    ('nonassoc', all_parens),
)

# Lookup tables derived from OPERATORS: operator title -> associativity
# ('left' / 'right' / 'nonassoc') and operator title -> precedence level.
assocs = {}
preds = {}

for i, (assoc, ops) in enumerate(OPERATORS):
    for op in ops:
        assocs[op] = assoc
        preds[op] = i

NEG_PRED = preds['neg']
FUNC_PRED = preds['function']
MAX_PRED = len(OPERATORS)


def is_function(node):
    """
    Check if a given node is a function.

    A node is considered a function if it is not a leaf, and is not known in
    the precedences list above.
    """
    return not node.is_leaf and node.title() not in preds


def pred(node):
    """
    Get the operator precedence of a node.

    Leaf nodes get MAX_PRED (higher than any operator) so that they are never
    parenthesized. Non-leaf nodes whose title is not a known operator are
    treated as function calls (FUNC_PRED).
    """
    if node.is_leaf:
        return MAX_PRED

    op = node.title()

    if node.is_negation():
        # A negated product or quotient is given the (weaker) precedence of
        # binary '-' instead of NEG_PRED.
        if node[0].title() in ('*', '/'):
            return preds['-']

        return NEG_PRED

    # Known operators; anything else is rendered as a function call
    return preds.get(op, FUNC_PRED)


def rightmost_node(node):
    """
    Follow the last child of *node* downwards and return the deepest node
    reached (a leaf or a node without children).
    """
    if node.is_leaf or not len(node):
        return node

    return rightmost_node(node[-1])


def is_unary_prefix(node):
    """
    Check if a node is a unary operator that is placed before the operand.
    """
    return not node.is_leaf and len(node) == 1 and node.title() == '-'


def is_left_assoc(op):
    """Check if operator title *op* is listed as left-associative."""
    return op in assocs and assocs[op] == 'left'


def is_right_assoc(op):
    """Check if operator title *op* is listed as right-associative."""
    return op in assocs and assocs[op] == 'right'


def is_id(node):
    """Check if *node* is a leaf whose title is not made of digits only."""
    return node.is_leaf and not node.title().isdigit()


def is_int(node):
    """Check if *node* is a leaf whose title consists of digits only."""
    return node.is_leaf and node.title().isdigit()


def is_power(node):
    """Check if *node* is a '^' (power) operator node."""
    return not node.is_leaf and node.title() == '^'


def preprocess_node(node):
    """
    Return a preprocessed copy of *node* that is ready for line generation.

    Every level of ``negated`` on a node is converted into an explicit
    wrapping unary '-' node, so that negation participates in operator
    precedence. A binary sum ``a + -b`` is rewritten as ``a - b``.
    """
    node = node.clone()
    node.preprocess_str_exp()

    if node.negated:
        node.negated -= 1
        return Node('-', preprocess_node(node))

    if not node.is_leaf:
        for i, child in enumerate(node):
            node[i] = preprocess_node(child)

        # Only merge binary sums: for an n-ary sum, building a binary '-'
        # node out of the first two operands would silently drop the rest.
        # N-ary "x + -y" is still merged textually at the end of
        # generate_line.
        if node.title() == '+' and len(node) == 2 \
                and node[1].is_negation():
            return Node('-', node[0], node[1][0])

    return node


def generate_line(root):
    """
    Print an expression tree in a single text line. Where needed, add
    parentheses.

    >>> from node import Node, Leaf
    >>> l0, l1 = Leaf(1), Leaf(2)
    >>> print(generate_line(l0))
    1
    >>> plus = Node('+', l0, l1)
    >>> print(generate_line(plus))
    1 + 2
    >>> plus2 = Node('+', l0, l1)
    >>> times = Node('*', plus, plus2)
    >>> print(generate_line(times))
    (1 + 2)(1 + 2)
    >>> l2 = Leaf(3)
    >>> uminus = Node('-', l2)
    >>> times = Node('*', plus, uminus)
    >>> print(generate_line(times))
    (1 + 2) * -3
    >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
    >>> print(generate_line(minus))
    2 - -15x
    >>> left = Node('/', Leaf(22), Leaf(77))
    >>> right = Node('/', Leaf(28), Leaf(77))
    >>> minus = Node('-', left, right)
    >>> print(generate_line(minus))
    22 / 77 - 28 / 77
    >>> plus = Node('+', left, Node('-', right))
    >>> print(generate_line(plus))
    22 / 77 - 28 / 77
    """
    if not root:
        return ''

    if root.is_leaf:
        return str(root)

    # Maps every visited node to its rendered string
    content = {}

    def mult_sign(left, right, lparens, rparens):
        """
        Choose the separator between two factors of a multiplication:
        '' (juxtaposition), ' ' or ' * '. *lparens* / *rparens* tell
        whether the factors will be wrapped in parentheses.
        """
        # If the left factor is itself a product, decide based on its
        # rightmost descendant instead
        if left.title() == '*':
            left = rightmost_node(left)

        # a * b -> ab
        # a * 2 -> a * 2
        # a * (b) -> a(b)
        # (a) * b -> (a)b
        # (a) * (b) -> (a)(b)
        # 2 * a -> 2a
        # a * sin(b) -> a sin(b)
        left_char = content[left][-1]
        right_char = content[right][0]

        left_paren = lparens or left_char in ')]}'
        right_paren = rparens or right_char in '([{'
        right_alpha = right_char.isalpha()
        left_simple = is_id(left) or is_int(left)

        if left_paren or (right_paren and left_simple) \
                or (is_id(left) and is_id(right)) \
                or (is_int(left) and right_alpha):
            return ''

        if is_id(left) and right_alpha:
            return ' '

        return ' * '

    def construct_unary(node):
        """Render a node with a single operand (prefix, postfix, brackets)."""
        op = node.title()
        value = node[0]
        strval = content[value]

        if op in all_parens:
            return op[0] + strval + op[1]

        parens = False

        if pred(value) < pred(node):
            parens = not value.is_negation()
        elif pred(value) == pred(node):
            parens = len(value) > 1

        if parens:
            strval = '(' + strval + ')'

        if node.is_postfix():
            return strval + node.operator()

        prefix = node.operator()

        # Separate a word-like prefix operator from its operand, except when
        # a function is directly followed by an opening bracket
        if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
            prefix += ' '

        return prefix + strval

    def construct_binary(node):
        """Render a two-operand operator, adding parentheses where needed."""
        op = node.title()
        op_pred = pred(node)

        if node.no_spacing:
            sep = node.operator()
        else:
            sep = ' ' + node.operator() + ' '

        left, right = node

        lstr = content[left]
        rstr = content[right]

        lpred = pred(left)
        rpred = pred(right)

        lparens = rparens = False
        unary_right = is_unary_prefix(right)

        if lpred < op_pred or (op in ('*', '/') and left.is_negation()):
            lparens = True
        elif lpred == op_pred:
            lparens = is_right_assoc(left.title()) or is_right_assoc(op)

        if rpred < op_pred:
            # A unary minus operand only needs parentheses when its own
            # compound operand binds weaker than this operator
            rparens = not unary_right \
                or (pred(right[0]) < op_pred and len(right[0]) > 1)
        elif rpred == op_pred and len(right) > 1:
            if right.title() == op:
                rparens = not is_right_assoc(op)
            elif is_left_assoc(right.title()):
                rparens = True

        # Check if multiplication sign is necessary
        if op == '*' and not unary_right:
            sep = mult_sign(left, right, lparens, rparens)

        if lparens:
            lstr = '(' + lstr + ')'

        if rparens:
            rstr = '(' + rstr + ')'

        return lstr + sep + rstr

    def construct_nary_mult(node):
        """Render a product of more than two factors."""
        op_pred = pred(node)
        lstr = content[node[0]]
        lparens = pred(node[0]) < op_pred or node[0].is_negation()

        if lparens:
            lstr = '(' + lstr + ')'

        for i, right in enumerate(node[1:]):
            unary_right = is_unary_prefix(right)
            rparens = pred(right) < op_pred

            # Same rule as construct_binary for a unary minus operand
            if rparens and unary_right:
                rparens = pred(right[0]) < op_pred and len(right[0]) > 1

            rstr = content[right]

            if rparens:
                rstr = '(' + rstr + ')'

            # Never juxtapose a unary minus factor: "(a)-b" would read as a
            # subtraction (construct_binary also forces ' * ' here)
            if unary_right:
                sign = ' * '
            else:
                # node[i] is the factor directly before `right`
                sign = mult_sign(node[i], right, lparens, rparens)

            lstr += sign + rstr
            lparens = rparens

        return lstr

    def construct_nary(node):
        """Render an operator with more than two operands."""
        if node.title() == '*':
            return construct_nary_mult(node)

        op_pred = pred(node)
        e = []

        for child in node:
            exp = content[child]

            if pred(child) < op_pred:
                exp = '(' + exp + ')'

            e.append(exp)

        return (' ' + node.operator() + ' ').join(e)

    def construct_function(node):
        """Render a function call as name(arg, ...)."""
        children = [content[child] for child in node]
        return '%s(%s)' % (node.operator(), ', '.join(children))

    # Convert negations to unary nodes to be able to account for operator
    # precedence
    root = preprocess_node(root.clone())

    # Traverse the expression tree and construct the mathematical expression in
    # the leafs and nodes in depth first order.
    for node in traverse_depth_first(root):
        custom = node.custom_line()

        if custom is not None:
            content[node] = custom
            continue

        if node.is_leaf:
            nodestr = str(node.value)
        else:
            arity = node.arity()

            if arity == 1:
                nodestr = construct_unary(node)
            elif is_function(node):
                nodestr = construct_function(node)
            elif arity == 2:
                nodestr = construct_binary(node)
            else:
                nodestr = construct_nary(node)

        content[node] = node.postprocess_str(nodestr)

    # Merge binary plus and unary minus signs into a binary minus. This is a
    # plain textual replacement over the whole rendered line.
    return content[root].replace('+ -', '- ')