line.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from traverse import traverse_depth_first
  2. OPERATORS = [
  3. ('+', '-'),
  4. ('*', '/', 'mod'),
  5. ('^', )
  6. ]
  7. MAX_PRED = len(OPERATORS)
  8. def is_operator(node):
  9. """
  10. Check if a given node is an operator (otherwise, it's a function).
  11. """
  12. label = node.title()
  13. return any(map(lambda x: label in x, OPERATORS))
  14. def pred(node):
  15. """
  16. Get the precedence of an operator node.
  17. """
  18. # Check binary and n-ary operators
  19. if not node.is_leaf and len(node) > 1:
  20. op = node.title()
  21. for i, group in enumerate(OPERATORS):
  22. if op in group:
  23. return i
  24. # Unary operator and leaves have highest precedence
  25. return MAX_PRED
  26. def is_id(node):
  27. return node.is_leaf and not node.title().isdigit()
  28. def is_int(node):
  29. return node.is_leaf and node.title().isdigit()
  30. def is_power(node):
  31. return not node.is_leaf and node.title() == '^'
  32. def generate_line(root):
  33. """
  34. Print an expression tree in a single text line. Where needed, add
  35. parentheses.
  36. >>> from node import Node, Leaf
  37. >>> l0, l1 = Leaf(1), Leaf(2)
  38. >>> print generate_line(l0)
  39. 1
  40. >>> plus = Node('+', l0, l1)
  41. >>> print generate_line(plus)
  42. 1 + 2
  43. >>> plus2 = Node('+', l0, l1)
  44. >>> times = Node('*', plus, plus2)
  45. >>> print generate_line(times)
  46. (1 + 2)(1 + 2)
  47. >>> l2 = Leaf(3)
  48. >>> uminus = Node('-', l2)
  49. >>> times = Node('*', plus, uminus)
  50. >>> print generate_line(times)
  51. (1 + 2) * -3
  52. >>> exp = Leaf('x')
  53. >>> inf = Leaf('oo')
  54. >>> minus_inf = Node('-', inf)
  55. >>> integral = Node('int', exp, minus_inf, inf)
  56. >>> print generate_line(integral)
  57. int(x, -oo, oo)
  58. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  59. >>> print generate_line(minus)
  60. 2 - -15x
  61. >>> left = Node('/', Leaf(22), Leaf(77))
  62. >>> right = Node('/', Leaf(28), Leaf(77))
  63. >>> minus = Node('-', left, right)
  64. >>> print generate_line(minus)
  65. 22 / 77 - 28 / 77
  66. >>> plus = Node('+', left, Node('-', right))
  67. >>> print generate_line(plus)
  68. 22 / 77 - 28 / 77
  69. """
  70. if not root:
  71. return '<empty expression>'
  72. if root.is_leaf:
  73. return '-' * root.negated + root.title()
  74. content = {}
  75. def construct_unary(node):
  76. op = node.title()
  77. value = node[0]
  78. # -a
  79. # -3 * 4
  80. # --a
  81. if value.is_leaf \
  82. or not ' ' in content[value] or pred(value) > 0:
  83. return op + content[value]
  84. # -(a + b)
  85. return '%s(%s)' % (op, content[value])
  86. def construct_nary(node):
  87. op = node.title()
  88. # N-ary operator
  89. node_pred = pred(node)
  90. sep = ' ' + op + ' '
  91. e = []
  92. for i, child in enumerate(node):
  93. exp = content[child]
  94. #if i and op == '+' and exp[:2] == '-(':
  95. # exp = '-' + exp[2:-1]
  96. # print 'exp:', exp
  97. # Check if there is a precedence conflict
  98. # If so, add parentheses
  99. child_pred = pred(child)
  100. if child.negated:
  101. # (-a) ^ b
  102. if op == '^' and not i:
  103. exp = '(' + exp + ')'
  104. elif child_pred < node_pred:
  105. exp = '(' + exp + ')'
  106. elif child_pred == node_pred:
  107. if i and (op != child.title() \
  108. or (op == '+' and child[1].negated)):
  109. exp = '(' + exp + ')'
  110. elif not i and op == '^':
  111. exp = '(' + exp + ')'
  112. e.append(exp)
  113. if op == '*':
  114. # Check if an explicit multiplication sign is nessecary
  115. left, right = node
  116. # Get the previous multiplication element if the arity is
  117. # greater than 2
  118. if left.title() == '*':
  119. left = left[1]
  120. # a * b -> ab
  121. # a * 2 -> a * 2
  122. # a * (b) -> a(b)
  123. # (a) * b -> (a)b
  124. # (a) * (b) -> (a)(b)
  125. # 2 * a -> 2a
  126. left_paren = e[0][-1] == ')'
  127. right_paren = e[1][0] == '('
  128. left_suited = left_paren or is_id(left) or is_int(left)
  129. right_suited = right_paren
  130. if not right_suited and not right.negated:
  131. right_suited = is_id(right) or is_power(right)
  132. if left_suited and right_suited:
  133. sep = ''
  134. exp = sep.join(e)
  135. if node.negated and op not in ('*', '/', '^'):
  136. exp = '(' + exp + ')'
  137. return exp
  138. def construct_function(node):
  139. buf = []
  140. for child in node:
  141. buf.append(content[child])
  142. return '%s(%s)' % (node.title(), ', '.join(buf))
  143. # Traverse the expression tree and construct the mathematical expression in
  144. # the leafs and nodes in depth first order.
  145. for node in traverse_depth_first(root):
  146. if node.is_leaf:
  147. content[node] = node.title()
  148. else:
  149. arity = len(node)
  150. if is_operator(node):
  151. if arity == 1:
  152. content[node] = construct_unary(node)
  153. else:
  154. content[node] = construct_nary(node)
  155. else:
  156. content[node] = construct_function(node)
  157. # Add negations
  158. content[node] = '-' * node.negated + content[node]
  159. # Merge binary plus and unary minus signs into binary minus.
  160. return content[root].replace('+ -', '- ')