line.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from traverse import traverse_depth_first
  2. OPERATORS = [
  3. ('+', '-'),
  4. ('*', '/', 'mod'),
  5. ('^', )
  6. ]
  7. MAX_PRED = len(OPERATORS)
  8. def is_operator(node):
  9. """
  10. Check if a given node is an operator (otherwise, it's a function).
  11. """
  12. label = node.title()
  13. return any(map(lambda x: label in x, OPERATORS))
  14. def pred(node):
  15. """
  16. Get the precedence of an operator node.
  17. """
  18. # Check binary and n-ary operators
  19. if not node.is_leaf and len(node) > 1:
  20. op = node.title()
  21. for i, group in enumerate(OPERATORS):
  22. if op in group:
  23. return i
  24. # Unary operator and leaves have highest precedence
  25. return MAX_PRED
  26. def is_id(node):
  27. return node.is_leaf and not node.title().isdigit()
  28. def is_int(node):
  29. return node.is_leaf and node.title().isdigit()
  30. def is_power(node):
  31. return not node.is_leaf and node.title() == '^'
  32. def generate_line(root):
  33. """
  34. Print an expression tree in a single text line. Where needed, add
  35. parentheses.
  36. >>> from node import Node, Leaf
  37. >>> l0, l1 = Leaf(1), Leaf(2)
  38. >>> print generate_line(l0)
  39. 1
  40. >>> plus = Node('+', l0, l1)
  41. >>> print generate_line(plus)
  42. 1 + 2
  43. >>> plus2 = Node('+', l0, l1)
  44. >>> times = Node('*', plus, plus2)
  45. >>> print generate_line(times)
  46. (1 + 2)(1 + 2)
  47. >>> l2 = Leaf(3)
  48. >>> uminus = Node('-', l2)
  49. >>> times = Node('*', plus, uminus)
  50. >>> print generate_line(times)
  51. (1 + 2) * -3
  52. >>> exp = Leaf('x')
  53. >>> inf = Leaf('oo')
  54. >>> minus_inf = Node('-', inf)
  55. >>> integral = Node('int', exp, minus_inf, inf)
  56. >>> print generate_line(integral)
  57. int(x, -oo, oo)
  58. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  59. >>> print generate_line(minus)
  60. 2 - -15x
  61. >>> left = Node('/', Leaf(22), Leaf(77))
  62. >>> right = Node('/', Leaf(28), Leaf(77))
  63. >>> minus = Node('-', left, right)
  64. >>> print generate_line(minus)
  65. 22 / 77 - 28 / 77
  66. >>> plus = Node('+', left, Node('-', right))
  67. >>> print generate_line(plus)
  68. 22 / 77 - 28 / 77
  69. """
  70. if not root:
  71. return '<empty expression>'
  72. if root.is_leaf:
  73. return str(root)
  74. content = {}
  75. def construct_unary(node):
  76. op = node.title()
  77. value = node[0]
  78. # -a
  79. # -3 * 4
  80. # --a
  81. if value.is_leaf \
  82. or not ' ' in content[value] or pred(value) > 0:
  83. return op + content[value]
  84. # -(a + b)
  85. return '%s(%s)' % (op, content[value])
  86. def construct_nary(node):
  87. op = node.title()
  88. # N-ary operator
  89. node_pred = pred(node)
  90. sep = ' ' + op + ' '
  91. e = []
  92. for i, child in enumerate(node):
  93. exp = content[child]
  94. #if i and op == '+' and exp[:2] == '-(':
  95. # exp = '-' + exp[2:-1]
  96. # print 'exp:', exp
  97. # Check if there is a precedence conflict
  98. # If so, add parentheses
  99. child_pred = pred(child)
  100. if child.negated:
  101. # (-a) ^ b
  102. if op == '^' and not i:
  103. exp = '(' + exp + ')'
  104. elif child_pred < node_pred:
  105. exp = '(' + exp + ')'
  106. elif child_pred == node_pred:
  107. if i and (op != child.title() or op == '/' \
  108. or (op == '+' and child[1].negated)):
  109. exp = '(' + exp + ')'
  110. elif not i and op == '^':
  111. exp = '(' + exp + ')'
  112. e.append(exp)
  113. if op == '*':
  114. # Check if an explicit multiplication sign is nessecary
  115. left, right = node
  116. # Get the previous multiplication element if the arity is
  117. # greater than 2
  118. if left.title() == '*':
  119. left = left[1]
  120. # a * b -> ab
  121. # a * 2 -> a * 2
  122. # a * (b) -> a(b)
  123. # (a) * b -> (a)b
  124. # (a) * (b) -> (a)(b)
  125. # 2 * a -> 2a
  126. l = e[0][-1]
  127. r = e[1][0]
  128. left_simple = is_id(left) or is_int(left)
  129. if (r == '(' and left_simple) or (l == ')' and r != '-') \
  130. or (left_simple and r.isalpha()):
  131. sep = ''
  132. exp = sep.join(e)
  133. if node.negated and op not in ('*', '/', '^'):
  134. exp = '(' + exp + ')'
  135. return exp
  136. def construct_function(node):
  137. buf = []
  138. for child in node:
  139. buf.append(content[child])
  140. return '%s(%s)' % (node.title(), ', '.join(buf))
  141. # Traverse the expression tree and construct the mathematical expression in
  142. # the leafs and nodes in depth first order.
  143. for node in traverse_depth_first(root):
  144. if node.is_leaf:
  145. content[node] = str(node)
  146. else:
  147. arity = len(node)
  148. if is_operator(node):
  149. if arity == 1:
  150. content[node] = construct_unary(node)
  151. else:
  152. content[node] = construct_nary(node)
  153. else:
  154. content[node] = construct_function(node)
  155. # Add negations
  156. content[node] = '-' * node.negated + content[node]
  157. # Merge binary plus and unary minus signs into binary minus.
  158. return content[root].replace('+ -', '- ')