line.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. from traverse import traverse_depth_first
  2. from node import Node
  3. all_parens = ('()', '[]', '||', '{}')
  4. OPERATORS = (
  5. ('left', ('vv', )),
  6. ('left', ('^^', )),
  7. ('left', ('=', )),
  8. ('left', ('+', '-')),
  9. ('nonassoc', ('int', 'd/d')),
  10. ('left', ('*', 'mod')),
  11. ('left', ('/', )),
  12. ('nonassoc', ('\'', )),
  13. ('nonassoc', ('neg', )),
  14. ('nonassoc', ('function', )),
  15. ('right', ('^', '_')),
  16. ('nonassoc', all_parens),
  17. )
  18. assocs = {}
  19. preds = {}
  20. for i, (assoc, ops) in enumerate(OPERATORS):
  21. for op in ops:
  22. assocs[op] = assoc
  23. preds[op] = i
  24. NEG_PRED = preds['neg']
  25. FUNC_PRED = preds['function']
  26. MAX_PRED = len(OPERATORS)
  27. def is_function(node):
  28. """
  29. Check if a given node is a function. A node is considered a function if it
  30. is not a leaf, and is not known in the precedences list above.
  31. """
  32. return not node.is_leaf and node.title() not in preds
  33. def pred(node):
  34. """
  35. Get the operator precedence of a node. Leaf nodes have the highest
  36. precedence for practical reasons.
  37. """
  38. # Check known operators
  39. if not node.is_leaf:
  40. op = node.title()
  41. if node.is_negation():
  42. if node[0].title() in '*/':
  43. return preds['-']
  44. return NEG_PRED
  45. #if node.is_postfix() and not node[0].is_leaf \
  46. # and node[0].title() in all_parens:
  47. # return preds['()']
  48. if op in preds:
  49. return preds[op]
  50. return FUNC_PRED
  51. # Unary operator and leaves have highest precedence
  52. return MAX_PRED
  53. def rightmost_node(node):
  54. if node.is_leaf or not len(node):
  55. return node
  56. return rightmost_node(node[-1])
  57. def is_unary_prefix(node):
  58. """
  59. Check if a node is a unary operator that is placed before the operand.
  60. """
  61. return not node.is_leaf and len(node) == 1 and node.title() == '-'
  62. def is_left_assoc(op):
  63. return op in assocs and assocs[op] == 'left'
  64. def is_right_assoc(op):
  65. return op in assocs and assocs[op] == 'right'
  66. def is_id(node):
  67. return node.is_leaf and not node.title().isdigit()
  68. def is_int(node):
  69. return node.is_leaf and node.title().isdigit()
  70. def is_power(node):
  71. return not node.is_leaf and node.title() == '^'
  72. def preprocess_node(node):
  73. node = node.clone()
  74. node.preprocess_str_exp()
  75. if node.negated:
  76. node.negated -= 1
  77. return Node('-', preprocess_node(node))
  78. if not node.is_leaf:
  79. for i, child in enumerate(node):
  80. node[i] = preprocess_node(child)
  81. if node.title() == '+' and node[1].is_negation():
  82. return Node('-', node[0], node[1][0])
  83. return node
  84. def generate_line(root):
  85. """
  86. Print an expression tree in a single text line. Where needed, add
  87. parentheses.
  88. >>> from node import Node, Leaf
  89. >>> l0, l1 = Leaf(1), Leaf(2)
  90. >>> print generate_line(l0)
  91. 1
  92. >>> plus = Node('+', l0, l1)
  93. >>> print generate_line(plus)
  94. 1 + 2
  95. >>> plus2 = Node('+', l0, l1)
  96. >>> times = Node('*', plus, plus2)
  97. >>> print generate_line(times)
  98. (1 + 2)(1 + 2)
  99. >>> l2 = Leaf(3)
  100. >>> uminus = Node('-', l2)
  101. >>> times = Node('*', plus, uminus)
  102. >>> print generate_line(times)
  103. (1 + 2) * -3
  104. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  105. >>> print generate_line(minus)
  106. 2 - -15x
  107. >>> left = Node('/', Leaf(22), Leaf(77))
  108. >>> right = Node('/', Leaf(28), Leaf(77))
  109. >>> minus = Node('-', left, right)
  110. >>> print generate_line(minus)
  111. 22 / 77 - 28 / 77
  112. >>> plus = Node('+', left, Node('-', right))
  113. >>> print generate_line(plus)
  114. 22 / 77 - 28 / 77
  115. """
  116. if not root:
  117. return '<empty expression>'
  118. if root.is_leaf:
  119. return str(root)
  120. content = {}
  121. def mult_sign(left, right, lparens, rparens):
  122. # Get the previous multiplication element in an nary multiplication
  123. if left.title() == '*':
  124. left = rightmost_node(left)
  125. # a * b -> ab
  126. # a * 2 -> a * 2
  127. # a * (b) -> a(b)
  128. # (a) * b -> (a)b
  129. # (a) * (b) -> (a)(b)
  130. # 2 * a -> 2a
  131. # a * sin(b) -> a sin(b)
  132. left_char = content[left][-1]
  133. right_char = content[right][0]
  134. left_paren = lparens or left_char in ')]}'
  135. right_paren = rparens or right_char in '([{'
  136. right_alpha = right_char.isalpha()
  137. left_simple = is_id(left) or is_int(left)
  138. if left_paren or (right_paren and left_simple) \
  139. or (is_id(left) and is_id(right)) \
  140. or (is_int(left) and right_alpha):
  141. return ''
  142. if is_id(left) and right_alpha:
  143. return ' '
  144. return ' * '
  145. def construct_unary(node):
  146. op = node.title()
  147. value = node[0]
  148. strval = content[value]
  149. if op in ('()', '[]', '||', '{}'):
  150. return op[0] + strval + op[1]
  151. parens = False
  152. if pred(value) < pred(node):
  153. parens = not value.is_negation()
  154. elif pred(value) == pred(node):
  155. parens = len(value) > 1
  156. if parens:
  157. strval = '(' + strval + ')'
  158. if node.is_postfix():
  159. return strval + node.operator()
  160. prefix = node.operator()
  161. if prefix != '-' and not (strval[0] in '([|{' and is_function(node)):
  162. prefix += ' '
  163. return prefix + strval
  164. def construct_binary(node):
  165. op = node.title()
  166. op_pred = pred(node)
  167. if node.no_spacing:
  168. sep = node.operator()
  169. else:
  170. sep = ' ' + node.operator() + ' '
  171. left, right = node
  172. lstr = content[left]
  173. rstr = content[right]
  174. lpred = pred(left)
  175. rpred = pred(right)
  176. lparens = rparens = False
  177. unary_right = is_unary_prefix(right)
  178. if lpred < op_pred or (op in '*/' and left.is_negation()):
  179. lparens = True
  180. elif lpred == op_pred:
  181. lparens = is_right_assoc(left.title()) or is_right_assoc(op)
  182. if rpred < op_pred:
  183. rparens = not unary_right \
  184. or (op == '/' and right[0].title() in '*/')
  185. elif rpred == op_pred and len(right) > 1:
  186. if right.title() == op:
  187. rparens = not is_right_assoc(op)
  188. elif is_left_assoc(right.title()):
  189. rparens = True
  190. if lparens:
  191. lstr = '(' + lstr + ')'
  192. if rparens:
  193. rstr = '(' + rstr + ')'
  194. # Check if multiplication sign is necessary
  195. if op == '*' and not unary_right:
  196. sep = mult_sign(left, right, lparens, rparens)
  197. return lstr + sep + rstr
  198. def construct_nary_mult(node):
  199. op_pred = pred(node)
  200. lstr = content[node[0]]
  201. lparens = pred(node[0]) < op_pred or node[0].is_negation()
  202. if lparens:
  203. lstr = '(' + lstr + ')'
  204. for i, right in enumerate(node[1:]):
  205. rparens = pred(right) < op_pred
  206. rstr = content[right]
  207. if rparens:
  208. rstr = '(' + rstr + ')'
  209. sign = mult_sign(node[i], right, lparens, rparens)
  210. lstr += sign + rstr
  211. lparens = rparens
  212. return lstr
  213. def construct_nary(node):
  214. if node.title() == '*':
  215. return construct_nary_mult(node)
  216. op_pred = pred(node)
  217. e = []
  218. for child in node:
  219. exp = content[child]
  220. if pred(child) < op_pred:
  221. exp = '(' + exp + ')'
  222. e.append(exp)
  223. return (' ' + node.operator() + ' ').join(e)
  224. def construct_function(node):
  225. children = [content[child] for child in node]
  226. return '%s(%s)' % (node.operator(), ', '.join(children))
  227. # Convert negations to unary nodes to be able to account for operator
  228. # precedence
  229. root = preprocess_node(root.clone())
  230. # Traverse the expression tree and construct the mathematical expression in
  231. # the leafs and nodes in depth first order.
  232. for node in traverse_depth_first(root):
  233. custom = node.custom_line()
  234. if custom is not None:
  235. content[node] = custom
  236. continue
  237. if node.is_leaf:
  238. nodestr = str(node.value)
  239. else:
  240. arity = node.arity()
  241. if arity == 1:
  242. nodestr = construct_unary(node)
  243. elif is_function(node):
  244. nodestr = construct_function(node)
  245. elif arity == 2:
  246. nodestr = construct_binary(node)
  247. else:
  248. nodestr = construct_nary(node)
  249. content[node] = node.postprocess_str(nodestr)
  250. # Merge binary plus and unary minus signs into a binary minus
  251. return content[root].replace('+ -', '- ')