| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- from traverse import traverse_depth_first
- from node import Node
# Bracket-like operators: each string holds the opening and the closing
# delimiter, which are placed around the single operand.
all_parens = ('()', '[]', '||', '{}')

# Operator groups ordered from loosest to tightest binding. The precedence of
# an operator is the index of its group, so a larger index binds more tightly.
# Each group also records its associativity: 'left', 'right' or 'nonassoc'.
# 'neg' and 'function' are pseudo-operators that only provide the precedence
# levels used for negations and function applications (NEG_PRED, FUNC_PRED).
OPERATORS = (
    ('left', ('vv', )),
    ('left', ('^^', )),
    ('left', ('=', )),
    ('left', ('+', '-')),
    ('nonassoc', ('int', 'd/d')),
    ('left', ('*', 'mod')),
    ('left', ('/', )),
    ('nonassoc', ('\'', )),
    ('nonassoc', ('neg', )),
    ('nonassoc', ('function', )),
    ('right', ('^', '_')),
    ('nonassoc', all_parens),
)

# Lookup tables derived from OPERATORS:
#   assocs: operator -> associativity
#   preds:  operator -> precedence (index of its group)
# Comprehensions are used so that no loop variables (i, assoc, ops, op) leak
# into the module namespace, where they would be re-exported by star imports.
assocs = {op: assoc for assoc, ops in OPERATORS for op in ops}
preds = {op: i for i, (_, ops) in enumerate(OPERATORS) for op in ops}

NEG_PRED = preds['neg']
FUNC_PRED = preds['function']
# One higher than the precedence of every operator group.
MAX_PRED = len(OPERATORS)
def is_function(node):
    """
    Tell whether a node should be treated as a function application.

    Leaves never qualify. A non-leaf node qualifies when its title is not a
    key of the operator precedence table ``preds``.
    """
    if node.is_leaf:
        return False
    return node.title() not in preds
def pred(node):
    """
    Return the operator precedence of a node (a larger value binds tighter).

    - Leaves get MAX_PRED, the highest value, for practical reasons.
    - A negation gets NEG_PRED. When the negated node's title is '*' or '/',
      it gets the precedence of binary '-' instead.
    - A node whose title is in ``preds`` gets that entry.
    - Any other non-leaf node is treated as a function (FUNC_PRED).
    """
    if node.is_leaf:
        return MAX_PRED

    if node.is_negation():
        negated_title = node[0].title()
        return preds['-'] if negated_title in '*/' else NEG_PRED

    return preds.get(node.title(), FUNC_PRED)
def rightmost_node(node):
    """
    Follow the last child downwards and return the node where the walk stops.

    That is either the first leaf reached or the first non-leaf node that has
    no children. ``node`` itself is returned if it already is such a node.
    """
    current = node
    while not current.is_leaf and len(current):
        current = current[-1]
    return current
def is_unary_prefix(node):
    """
    Tell whether a node is a unary operator written before its operand.

    Only a non-leaf '-' node with exactly one child counts.
    """
    if node.is_leaf:
        return False
    return len(node) == 1 and node.title() == '-'
def is_left_assoc(op):
    """Tell whether ``op`` is registered as left-associative in ``assocs``."""
    return assocs.get(op) == 'left'
def is_right_assoc(op):
    """Tell whether ``op`` is registered as right-associative in ``assocs``."""
    return assocs.get(op) == 'right'
def is_id(node):
    """
    Tell whether a node is an identifier leaf.

    Any leaf whose title is not made up solely of digit characters counts,
    so a title such as '2.5' is also treated as an identifier.
    """
    if not node.is_leaf:
        return False
    return not node.title().isdigit()
def is_int(node):
    """
    Tell whether a node is an integer leaf.

    A leaf qualifies when its title consists only of digit characters
    (``str.isdigit``), e.g. '42'.
    """
    if not node.is_leaf:
        return False
    return node.title().isdigit()
def is_power(node):
    """Tell whether a node is a non-leaf exponentiation ('^') node."""
    if node.is_leaf:
        return False
    return node.title() == '^'
def preprocess_node(node):
    """
    Return a preprocessed copy of ``node`` for line generation.

    Works on clones (``node.clone()``) rather than on the argument itself.
    For each cloned node it:
    - calls ``preprocess_str_exp()``;
    - turns every level of the ``negated`` counter into an explicit unary
      ``Node('-', ...)`` wrapper, so negation takes part in operator
      precedence;
    - recursively preprocesses all children;
    - rewrites a binary '+' whose right operand is a negation, ``a + -b``,
      into a binary minus ``a - b``.
    """
    node = node.clone()
    node.preprocess_str_exp()

    if node.negated:
        # Peel off one negation level; the recursive call handles the rest.
        node.negated -= 1
        return Node('-', preprocess_node(node))

    if not node.is_leaf:
        for i, child in enumerate(node):
            node[i] = preprocess_node(child)

        # Only collapse the binary form. For an n-ary '+', building
        # Node('-', node[0], node[1][0]) would silently drop node[2:], and a
        # one-child '+' has no node[1] at all. Other '+ -' sequences are
        # merged by the final string replacement in generate_line.
        if node.title() == '+' and len(node) == 2 \
                and node[1].is_negation():
            return Node('-', node[0], node[1][0])

    return node
def generate_line(root):
    """
    Print an expression tree in a single text line. Where needed, add
    parentheses.

    The tree is first passed through preprocess_node, which turns negations
    into explicit unary '-' nodes so that they take part in operator
    precedence. Every node is then rendered into the ``content`` dict in the
    order produced by traverse_depth_first. The construct_* helpers read their
    children's strings from ``content``, so children must be visited before
    their parent. Finally, every '+ -' in the result is merged into '- '.

    Returns '<empty expression>' when ``root`` is falsy, and ``str(root)``
    when ``root`` is a leaf.

    >>> from node import Node, Leaf
    >>> l0, l1 = Leaf(1), Leaf(2)
    >>> print(generate_line(l0))
    1
    >>> plus = Node('+', l0, l1)
    >>> print(generate_line(plus))
    1 + 2
    >>> plus2 = Node('+', l0, l1)
    >>> times = Node('*', plus, plus2)
    >>> print(generate_line(times))
    (1 + 2)(1 + 2)
    >>> l2 = Leaf(3)
    >>> uminus = Node('-', l2)
    >>> times = Node('*', plus, uminus)
    >>> print(generate_line(times))
    (1 + 2) * -3
    >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
    >>> print(generate_line(minus))
    2 - -15x
    >>> left = Node('/', Leaf(22), Leaf(77))
    >>> right = Node('/', Leaf(28), Leaf(77))
    >>> minus = Node('-', left, right)
    >>> print(generate_line(minus))
    22 / 77 - 28 / 77
    >>> plus = Node('+', left, Node('-', right))
    >>> print(generate_line(plus))
    22 / 77 - 28 / 77
    """
    # NOTE(review): this relies on Node truthiness (presumably len() for
    # non-leaf nodes) -- confirm which roots are meant to count as empty.
    if not root:
        return '<empty expression>'

    if root.is_leaf:
        return str(root)

    # Maps every visited node to its rendered string. Parentheses that a
    # parent adds around a child are NOT stored here; the parent adds them
    # locally when building its own string.
    content = {}

    def mult_sign(left, right, lparens, rparens):
        """
        Return the separator placed between two adjacent factors of a
        multiplication: '' (plain juxtaposition), ' ' or ' * '.

        ``lparens`` and ``rparens`` tell whether the caller wraps the left or
        right factor in parentheses.
        """
        # Get the previous multiplication element in an nary multiplication.
        # NOTE(review): rightmost_node follows the last child through ANY
        # non-leaf node, not only nested '*' nodes. For a left operand such
        # as a * b ^ 2, the "previous element" becomes the exponent leaf 2,
        # and the is_int/right_alpha rule below then glues a following
        # identifier onto it without ' * '. Confirm whether the walk should
        # stop at the last '*' factor.
        if left.title() == '*':
            left = rightmost_node(left)

        # a * b -> ab
        # a * 2 -> a * 2
        # a * (b) -> a(b)
        # (a) * b -> (a)b
        # (a) * (b) -> (a)(b)
        # 2 * a -> 2a
        # a * sin(b) -> a sin(b)
        left_char = content[left][-1]
        right_char = content[right][0]
        # A factor counts as bracketed when the caller wraps it, or when its
        # own string already ends/starts with a closing/opening bracket.
        left_paren = lparens or left_char in ')]}'
        right_paren = rparens or right_char in '([{'
        right_alpha = right_char.isalpha()
        left_simple = is_id(left) or is_int(left)

        if left_paren or (right_paren and left_simple) \
                or (is_id(left) and is_id(right)) \
                or (is_int(left) and right_alpha):
            return ''

        if is_id(left) and right_alpha:
            return ' '

        return ' * '

    def construct_unary(node):
        """
        Render a node with exactly one operand: bracket pairs, postfix
        operators, and prefix operators / one-argument functions.
        """
        op = node.title()
        value = node[0]
        strval = content[value]

        # Bracket-like operators wrap their operand in the delimiter pair.
        # (This tuple has the same contents as the module-level all_parens.)
        if op in ('()', '[]', '||', '{}'):
            return op[0] + strval + op[1]

        parens = False

        # Parenthesize a looser-binding operand (negations excepted), or an
        # equal-precedence operand that has more than one child.
        if pred(value) < pred(node):
            parens = not value.is_negation()
        elif pred(value) == pred(node):
            parens = len(value) > 1

        if parens:
            strval = '(' + strval + ')'

        if node.is_postfix():
            return strval + node.operator()

        prefix = node.operator()

        # Put a space between the prefix and its operand, except for unary
        # minus and for a function whose argument string already starts with
        # an opening delimiter.
        if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
            prefix += ' '

        return prefix + strval

    def construct_binary(node):
        """
        Render a two-operand operator node, adding parentheses around an
        operand where precedence or associativity requires it.
        """
        op = node.title()
        op_pred = pred(node)

        # Nodes flagged with no_spacing render the operator without spaces.
        if node.no_spacing:
            sep = node.operator()
        else:
            sep = ' ' + node.operator() + ' '

        left, right = node
        lstr = content[left]
        rstr = content[right]
        lpred = pred(left)
        rpred = pred(right)
        lparens = rparens = False
        unary_right = is_unary_prefix(right)

        # Left operand: parenthesize it when it binds more loosely, when it
        # is a negation on the left of '*' or '/', or, at equal precedence,
        # when either operator is right-associative.
        if lpred < op_pred or (op in '*/' and left.is_negation()):
            lparens = True
        elif lpred == op_pred:
            lparens = is_right_assoc(left.title()) or is_right_assoc(op)

        # Right operand: a looser-binding operand is parenthesized unless it
        # is a unary minus. At equal precedence (and with more than one
        # child), the same operator needs parentheses unless it is
        # right-associative, and a different left-associative operator
        # always does.
        if rpred < op_pred:
            rparens = not unary_right
        elif rpred == op_pred and len(right) > 1:
            if right.title() == op:
                rparens = not is_right_assoc(op)
            elif is_left_assoc(right.title()):
                rparens = True

        # Check if multiplication sign is necessary
        if op == '*' and not unary_right:
            sep = mult_sign(left, right, lparens, rparens)

        if lparens:
            lstr = '(' + lstr + ')'

        if rparens:
            rstr = '(' + rstr + ')'

        return lstr + sep + rstr

    def construct_nary_mult(node):
        """
        Render a multiplication with more than two factors, choosing the
        separator between each pair of neighbouring factors via mult_sign.
        """
        op_pred = pred(node)

        lstr = content[node[0]]
        # A leading negation is always parenthesized.
        lparens = pred(node[0]) < op_pred or node[0].is_negation()

        if lparens:
            lstr = '(' + lstr + ')'

        # node[i] is the factor immediately before ``right``.
        for i, right in enumerate(node[1:]):
            rparens = pred(right) < op_pred
            rstr = content[right]

            if rparens:
                rstr = '(' + rstr + ')'

            sign = mult_sign(node[i], right, lparens, rparens)
            lstr += sign + rstr
            # The current factor becomes the left factor of the next pair.
            lparens = rparens

        return lstr

    def construct_nary(node):
        """
        Render an operator node with more than two operands, joining the
        operands with the spaced operator. Multiplications are delegated to
        construct_nary_mult.
        """
        if node.title() == '*':
            return construct_nary_mult(node)

        op_pred = pred(node)
        e = []

        for child in node:
            exp = content[child]

            # Only looser-binding operands are parenthesized here.
            if pred(child) < op_pred:
                exp = '(' + exp + ')'

            e.append(exp)

        return (' ' + node.operator() + ' ').join(e)

    def construct_function(node):
        """Render a multi-argument function as ``name(arg1, arg2, ...)``."""
        children = [content[child] for child in node]
        return '%s(%s)' % (node.operator(), ', '.join(children))

    # Convert negations to unary nodes to be able to account for operator
    # precedence
    root = preprocess_node(root.clone())

    # Traverse the expression tree and construct the mathematical expression
    # in the leaves and nodes in depth first order.
    for node in traverse_depth_first(root):
        # A node-provided custom rendering takes precedence and is stored
        # as-is (postprocess_str is not applied to it).
        custom = node.custom_line()

        if custom is not None:
            content[node] = custom
            continue

        if node.is_leaf:
            nodestr = str(node.value)
        else:
            arity = node.arity()

            # One-operand nodes (including one-argument functions) are
            # handled first; is_function only matters for arity >= 2.
            if arity == 1:
                nodestr = construct_unary(node)
            elif is_function(node):
                nodestr = construct_function(node)
            elif arity == 2:
                nodestr = construct_binary(node)
            else:
                nodestr = construct_nary(node)

        content[node] = node.postprocess_str(nodestr)

    # Merge binary plus and unary minus signs into a binary minus.
    # NOTE(review): this is a plain substring replacement over the whole
    # rendered line, so it also applies inside custom_line() strings.
    return content[root].replace('+ -', '- ')
|