line.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from traverse import traverse_depth_first
  2. OPERATORS = [
  3. ('+', '-'),
  4. ('*', '/', 'mod'),
  5. ('^', )
  6. ]
  7. MAX_PRED = len(OPERATORS)
  8. def is_operator(node):
  9. """
  10. Check if a given node is an operator (otherwise, it's a function).
  11. """
  12. label = node.title()
  13. return any(map(lambda x: label in x, OPERATORS))
  14. def pred(node):
  15. """
  16. Get the precedence of an operator node.
  17. """
  18. # Check binary and n-ary operators
  19. if not node.is_leaf and len(node) > 1:
  20. op = node.title()
  21. for i, group in enumerate(OPERATORS):
  22. if op in group:
  23. return i
  24. # Unary operator and leaves have highest precedence
  25. return MAX_PRED
  26. def is_id(node):
  27. return node.is_leaf and not node.title().isdigit()
  28. def is_power(node):
  29. return not node.is_leaf and node.title() == '^'
  30. def generate_line(root):
  31. """
  32. Print an expression tree in a single text line. Where needed, add
  33. parentheses.
  34. >>> from node import Node, Leaf
  35. >>> l0, l1 = Leaf(1), Leaf(2)
  36. >>> print generate_line(l0)
  37. 1
  38. >>> plus = Node('+', l0, l1)
  39. >>> print generate_line(plus)
  40. 1 + 2
  41. >>> plus2 = Node('+', l0, l1)
  42. >>> times = Node('*', plus, plus2)
  43. >>> print generate_line(times)
  44. (1 + 2)(1 + 2)
  45. >>> l2 = Leaf(3)
  46. >>> uminus = Node('-', l2)
  47. >>> times = Node('*', plus, uminus)
  48. >>> print generate_line(times)
  49. (1 + 2) * -3
  50. >>> exp = Leaf('x')
  51. >>> inf = Leaf('oo')
  52. >>> minus_inf = Node('-', inf)
  53. >>> integral = Node('int', exp, minus_inf, inf)
  54. >>> print generate_line(integral)
  55. int(x, -oo, oo)
  56. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  57. >>> print generate_line(minus)
  58. 2 - -15x
  59. >>> left = Node('/', Leaf(22), Leaf(77))
  60. >>> right = Node('/', Leaf(28), Leaf(77))
  61. >>> minus = Node('-', left, right)
  62. >>> print generate_line(minus)
  63. 22 / 77 - 28 / 77
  64. >>> plus = Node('+', left, Node('-', right))
  65. >>> print generate_line(plus)
  66. 22 / 77 - 28 / 77
  67. """
  68. if not root:
  69. return '<empty expression>'
  70. if root.is_leaf:
  71. return str(root)
  72. content = {}
  73. def construct_unary(node):
  74. op = node.title()
  75. value = node[0]
  76. # -a
  77. # -3 * 4
  78. # --a
  79. if value.is_leaf \
  80. or not ' ' in content[value] or pred(value) > 0:
  81. return op + content[value]
  82. # -(a + b)
  83. return '%s(%s)' % (op, content[value])
  84. def construct_nary(node):
  85. op = node.title()
  86. # N-ary operator
  87. node_pred = pred(node)
  88. sep = ' ' + op + ' '
  89. e = []
  90. for i, child in enumerate(node):
  91. exp = content[child]
  92. #if i and op == '+' and exp[:2] == '-(':
  93. # exp = '-' + exp[2:-1]
  94. # print 'exp:', exp
  95. # Check if there is a precedence conflict
  96. # If so, add parentheses
  97. child_pred = pred(child)
  98. if child.negated:
  99. # (-a) ^ b
  100. if op == '^' and not i:
  101. exp = '(' + exp + ')'
  102. elif child_pred < node_pred:
  103. exp = '(' + exp + ')'
  104. elif child_pred == node_pred:
  105. if i and (op != child.title() \
  106. or (op == '+' and child[1].negated)):
  107. exp = '(' + exp + ')'
  108. elif not i and op == '^':
  109. exp = '(' + exp + ')'
  110. e.append(exp)
  111. if op == '*':
  112. # Check if an explicit multiplication sign is nessecary
  113. left, right = node
  114. # Get the previous multiplication element if the arity is
  115. # greater than 2
  116. if left.title() == '*':
  117. left = left[1]
  118. # a * b -> ab
  119. # a * 2 -> a * 2
  120. # a * (b) -> a(b)
  121. # (a) * b -> (a)b
  122. # (a) * (b) -> (a)(b)
  123. # 2 * a -> 2a
  124. left_end = e[0][-1]
  125. right_begin = e[1][0]
  126. left_suited = left_end == ')' or left_end.isdigit() or is_id(left)
  127. right_suited = right_begin == '('
  128. if not right_suited and not right.negated:
  129. right_suited = is_id(right) or is_power(right)
  130. if left_suited and right_suited:
  131. sep = ''
  132. exp = sep.join(e)
  133. if node.negated and op not in ('*', '/', '^'):
  134. exp = '(' + exp + ')'
  135. return exp
  136. def construct_function(node):
  137. buf = []
  138. for child in node:
  139. buf.append(content[child])
  140. return '%s(%s)' % (node.title(), ', '.join(buf))
  141. # Traverse the expression tree and construct the mathematical expression in
  142. # the leafs and nodes in depth first order.
  143. for node in traverse_depth_first(root):
  144. if node.is_leaf:
  145. content[node] = str(node)
  146. else:
  147. arity = len(node)
  148. if is_operator(node):
  149. if arity == 1:
  150. content[node] = construct_unary(node)
  151. else:
  152. content[node] = construct_nary(node)
  153. else:
  154. content[node] = construct_function(node)
  155. # Add negations
  156. content[node] = '-' * node.negated + content[node]
  157. # Merge binary plus and unary minus signs into binary minus.
  158. return content[root].replace('+ -', '- ')