line.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. from traverse import traverse_depth_first
  2. OPERATORS = [
  3. ('+', '-'),
  4. ('*', '/', 'mod'),
  5. ('^', )
  6. ]
  7. MAX_PRED = len(OPERATORS)
  8. def is_operator(node):
  9. """
  10. Check if a given node is an operator (otherwise, it's a function).
  11. """
  12. label = node.title()
  13. return any(map(lambda x: label in x, OPERATORS))
  14. def pred(node):
  15. """
  16. Get the precedence of an operator node.
  17. """
  18. # Check binary and n-ary operators
  19. if not node.is_leaf and len(node) > 1:
  20. op = node.title()
  21. for i, group in enumerate(OPERATORS):
  22. if op in group:
  23. return i
  24. # Unary operator and leaves have highest precedence
  25. return MAX_PRED
  26. def generate_line_old(root):
  27. """
  28. Print an expression tree in a single text line. Where needed, add
  29. parentheses.
  30. >>> from node import Node, Leaf
  31. >>> l0, l1 = Leaf(1), Leaf(2)
  32. >>> plus = Node('+', l0, l1)
  33. >>> print generate_line_old(plus)
  34. 1 + 2
  35. >>> plus2 = Node('+', l0, l1)
  36. >>> times = Node('*', plus, plus2)
  37. >>> print generate_line_old(times)
  38. (1 + 2)(1 + 2)
  39. >>> l2 = Leaf(3)
  40. >>> uminus = Node('-', l2)
  41. >>> times = Node('*', plus, uminus)
  42. >>> print generate_line_old(times)
  43. (1 + 2) * -3
  44. >>> exp = Leaf('x')
  45. >>> inf = Leaf('oo')
  46. >>> minus_inf = Node('-', inf)
  47. >>> integral = Node('int', exp, minus_inf, inf)
  48. >>> print generate_line_old(integral)
  49. int(x, -oo, oo)
  50. """
  51. def traverse(node):
  52. """
  53. The expression tree is traversed using preorder traversal:
  54. 1. Visit the root
  55. 2. Traverse the subtrees in left-to-right order
  56. """
  57. if not node:
  58. return '<empty expression>'
  59. op = node.title()
  60. if not node.nodes:
  61. return op
  62. arity = len(node)
  63. if is_operator(node):
  64. if arity == 1:
  65. # Unary operator
  66. sub = node[0]
  67. sub_exp = traverse(sub)
  68. # Negated sub-expressions with spaces in them should be
  69. # enclosed in parentheses, unless they have a higher precedence
  70. # than subtraction are rewritten to a factor of a subtraction:
  71. # -(1 + 2)
  72. # -(1 - 2)
  73. # -4a
  74. # -(4 * 5)
  75. # 1 - 4 * 5
  76. # 1 + -(4 * 5) -> 1 - 4 * 5
  77. if ' ' in sub_exp and sub.is_Leaf \
  78. and hasattr(node, 'marked_negation') \
  79. and pred(sub) > 0:
  80. sub_exp = '(' + sub_exp + ')'
  81. result = op + sub_exp
  82. else:
  83. # N-ary operator
  84. node_pred = pred(node)
  85. result = ''
  86. sep = ' ' + op + ' '
  87. e = []
  88. # Mark added and subtracted negations for later use when adding
  89. # parentheses
  90. if op in ('+', '-'):
  91. for child in node:
  92. if child.title() == '-' and len(child) == 1:
  93. child.marked_negation = True
  94. for i, child in enumerate(node):
  95. exp = traverse(child)
  96. # Check if there is a precedence conflict
  97. # If so, add parentheses
  98. child_pred = pred(child)
  99. if child_pred < node_pred or \
  100. (i and child_pred == node_pred \
  101. and op != child.title()):
  102. exp = '(' + exp + ')'
  103. e.append(exp)
  104. if op == '*':
  105. # Check if an explicit multiplication sign is nessecary
  106. left, right = node
  107. # Get the previous multiplication element if the arity is
  108. # greater than 2
  109. if left.title() == '*':
  110. left = left[1]
  111. # a * b -> ab
  112. # a * 2 -> a * 2
  113. # a * (b) -> a(b)
  114. # (a) * b -> (a)b
  115. # (a) * (b) -> (a)(b)
  116. # 2 * a -> 2a
  117. left_id = is_id(left)
  118. right_id = is_id(right)
  119. left_paren = e[0][-1] == ')'
  120. right_paren = e[1][0] == '('
  121. left_int = is_int(left)
  122. if (left_id or left_paren or left_int) \
  123. and (right_id or right_paren):
  124. sep = ''
  125. result += sep.join(e)
  126. else:
  127. # Function call
  128. result = op + '(' + ', '.join(map(traverse, node)) + ')'
  129. return result
  130. # An addition with negation can be written as a subtraction, e.g.:
  131. # 1 + -2 -> 1 - 2
  132. return traverse(root).replace('+ -', '- ')
  133. def is_id(node):
  134. return node.is_leaf and not node.title().isdigit()
  135. def is_int(node):
  136. return node.is_leaf and node.title().isdigit()
  137. def generate_line(root):
  138. """
  139. Print an expression tree in a single text line. Where needed, add
  140. parentheses.
  141. >>> from node import Node, Leaf
  142. >>> l0, l1 = Leaf(1), Leaf(2)
  143. >>> print generate_line(l0)
  144. 1
  145. >>> plus = Node('+', l0, l1)
  146. >>> print generate_line(plus)
  147. 1 + 2
  148. >>> plus2 = Node('+', l0, l1)
  149. >>> times = Node('*', plus, plus2)
  150. >>> print generate_line(times)
  151. (1 + 2)(1 + 2)
  152. >>> l2 = Leaf(3)
  153. >>> uminus = Node('-', l2)
  154. >>> times = Node('*', plus, uminus)
  155. >>> print generate_line(times)
  156. (1 + 2) * -3
  157. >>> exp = Leaf('x')
  158. >>> inf = Leaf('oo')
  159. >>> minus_inf = Node('-', inf)
  160. >>> integral = Node('int', exp, minus_inf, inf)
  161. >>> print generate_line(integral)
  162. int(x, -oo, oo)
  163. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  164. >>> print generate_line(minus)
  165. 2 - -15x
  166. >>> left = Node('/', Leaf(22), Leaf(77))
  167. >>> right = Node('/', Leaf(28), Leaf(77))
  168. >>> minus = Node('-', left, right)
  169. >>> print generate_line(minus)
  170. 22 / 77 - 28 / 77
  171. >>> plus = Node('+', left, Node('-', right))
  172. >>> print generate_line(plus)
  173. 22 / 77 - 28 / 77
  174. """
  175. if not root:
  176. return '<empty expression>'
  177. if root.is_leaf:
  178. return '-' * root.negated + root.title()
  179. content = {}
  180. def construct_unary(node):
  181. result = node.title()
  182. value = node[0]
  183. # -a
  184. # -3*4
  185. # --a
  186. if value.is_leaf \
  187. or not ' ' in content[value] or pred(value) > 0:
  188. return result + content[value]
  189. return '%s(%s)' % (result, content[value])
  190. def construct_nary(node):
  191. op = node.title()
  192. # N-ary operator
  193. node_pred = pred(node)
  194. sep = ' ' + op + ' '
  195. e = []
  196. for i, child in enumerate(node):
  197. exp = content[child]
  198. # Check if there is a precedence conflict
  199. # If so, add parentheses
  200. child_pred = pred(child)
  201. if child_pred < node_pred or \
  202. (i and child_pred == node_pred \
  203. and op != child.title()):
  204. exp = '(' + exp + ')'
  205. e.append(exp)
  206. if op == '*':
  207. # Check if an explicit multiplication sign is nessecary
  208. left, right = node
  209. # Get the previous multiplication element if the arity is
  210. # greater than 2
  211. if left.title() == '*':
  212. left = left[1]
  213. # a * b -> ab
  214. # a * 2 -> a * 2
  215. # a * (b) -> a(b)
  216. # (a) * b -> (a)b
  217. # (a) * (b) -> (a)(b)
  218. # 2 * a -> 2a
  219. left_paren = e[0][-1] == ')'
  220. right_paren = e[1][0] == '('
  221. if (is_id(left) or left_paren or is_int(left)) \
  222. and ((not right.negated and is_id(right)) or right_paren):
  223. sep = ''
  224. exp = sep.join(e)
  225. #if node.negated:
  226. # FIXME: Keep it this way?
  227. if node.negated and op != '*':
  228. exp = '(' + exp + ')'
  229. return exp
  230. def construct_function(node):
  231. buf = []
  232. for child in node:
  233. buf.append(content[child])
  234. return '%s(%s)' % (node.title(), ', '.join(buf))
  235. # Traverse the expression tree and construct the mathematical expression in
  236. # the leafs and nodes in depth first order.
  237. for node in traverse_depth_first(root):
  238. if node.is_leaf:
  239. content[node] = node.title()
  240. else:
  241. arity = len(node)
  242. if is_operator(node):
  243. if arity == 1:
  244. content[node] = construct_unary(node)
  245. else:
  246. content[node] = construct_nary(node)
  247. else:
  248. content[node] = construct_function(node)
  249. # Add negations
  250. content[node] = '-' * node.negated + content[node]
  251. # Merge binary plus and unary minus signs into binary minus.
  252. return content[root].replace('+ -', '- ')