line.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from traverse import traverse_depth_first
  2. OPERATORS = (
  3. ('vv', ),
  4. ('^^', ),
  5. ('=', ),
  6. ('+', '-'),
  7. ('*', 'mod'),
  8. ('/', ),
  9. ('^', '_'),
  10. )
  11. NEG_PRED = 3
  12. MAX_PRED = len(OPERATORS)
  13. def is_operator(node):
  14. """
  15. Check if a given node is an operator (otherwise, it's a function).
  16. """
  17. label = node.title()
  18. return any(map(lambda x: label in x, OPERATORS))
  19. def pred(node):
  20. """
  21. Get the precedence of an operator node.
  22. """
  23. # Check binary and n-ary operators
  24. if not node.is_leaf and len(node) > 1:
  25. op = node.title()
  26. for i, group in enumerate(OPERATORS):
  27. if op in group:
  28. return i
  29. # Unary operator and leaves have highest precedence
  30. return MAX_PRED
  31. def is_id(node):
  32. return node.is_leaf and not node.title().isdigit()
  33. def is_int(node):
  34. return node.is_leaf and node.title().isdigit()
  35. def is_power(node):
  36. return not node.is_leaf and node.title() == '^'
  37. def generate_line(root):
  38. """
  39. Print an expression tree in a single text line. Where needed, add
  40. parentheses.
  41. >>> from node import Node, Leaf
  42. >>> l0, l1 = Leaf(1), Leaf(2)
  43. >>> print generate_line(l0)
  44. 1
  45. >>> plus = Node('+', l0, l1)
  46. >>> print generate_line(plus)
  47. 1 + 2
  48. >>> plus2 = Node('+', l0, l1)
  49. >>> times = Node('*', plus, plus2)
  50. >>> print generate_line(times)
  51. (1 + 2)(1 + 2)
  52. >>> l2 = Leaf(3)
  53. >>> uminus = Node('-', l2)
  54. >>> times = Node('*', plus, uminus)
  55. >>> print generate_line(times)
  56. (1 + 2) * -3
  57. >>> exp = Leaf('x')
  58. >>> inf = Leaf('oo')
  59. >>> minus_inf = Node('-', inf)
  60. >>> integral = Node('int', exp, minus_inf, inf)
  61. >>> print generate_line(integral)
  62. int(x, -oo, oo)
  63. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  64. >>> print generate_line(minus)
  65. 2 - -15x
  66. >>> left = Node('/', Leaf(22), Leaf(77))
  67. >>> right = Node('/', Leaf(28), Leaf(77))
  68. >>> minus = Node('-', left, right)
  69. >>> print generate_line(minus)
  70. 22 / 77 - 28 / 77
  71. >>> plus = Node('+', left, Node('-', right))
  72. >>> print generate_line(plus)
  73. 22 / 77 - 28 / 77
  74. """
  75. if not root:
  76. return '<empty expression>'
  77. if root.is_leaf:
  78. return str(root)
  79. content = {}
  80. def construct_unary(node):
  81. op = node.title()
  82. value = node[0]
  83. # -a
  84. # -3 * 4
  85. # --a
  86. if value.is_leaf \
  87. or not ' ' in content[value] or pred(value) > NEG_PRED:
  88. return op + content[value]
  89. # -(a + b)
  90. return '%s(%s)' % (op, content[value])
  91. def construct_nary(node):
  92. op = node.title()
  93. # N-ary operator
  94. node_pred = pred(node)
  95. sep = ' ' + op + ' '
  96. e = []
  97. for i, child in enumerate(node):
  98. exp = content[child]
  99. #if i and op == '+' and exp[:2] == '-(':
  100. # exp = '-' + exp[2:-1]
  101. # print 'exp:', exp
  102. # Check if there is a precedence conflict
  103. # If so, add parentheses
  104. child_pred = pred(child)
  105. if child.negated:
  106. # (-a) ^ b
  107. # -a ^ -b
  108. # (-a) * b
  109. # a * -b
  110. # (-a) / b
  111. if node_pred > NEG_PRED:
  112. exp = '(' + exp + ')'
  113. elif child_pred < node_pred:
  114. exp = '(' + exp + ')'
  115. elif child_pred == node_pred:
  116. if i and (op != child.title() or op == '/' \
  117. or (op == '+' and child[1].negated)):
  118. exp = '(' + exp + ')'
  119. elif not i and op == '^':
  120. exp = '(' + exp + ')'
  121. e.append(exp)
  122. if op == '*':
  123. # Check if an explicit multiplication sign is nessecary
  124. left, right = node
  125. # Get the previous multiplication element if the arity is
  126. # greater than 2
  127. if left.title() == '*':
  128. left = left[1]
  129. # a * b -> ab
  130. # a * 2 -> a * 2
  131. # a * (b) -> a(b)
  132. # (a) * b -> (a)b
  133. # (a) * (b) -> (a)(b)
  134. # 2 * a -> 2a
  135. l = e[0][-1]
  136. r = e[1][0]
  137. left_simple = is_id(left) or is_int(left)
  138. if (r in ('(', '[') and left_simple) or (l == ')' and r != '-') \
  139. or (left_simple and r.isalpha()):
  140. sep = ''
  141. exp = sep.join(e)
  142. if node.negated and op not in ('*', '/', '^'):
  143. exp = '(' + exp + ')'
  144. return exp
  145. def construct_function(node):
  146. buf = []
  147. for child in node:
  148. buf.append(content[child])
  149. return '%s(%s)' % (node.title(), ', '.join(buf))
  150. # Traverse the expression tree and construct the mathematical expression in
  151. # the leafs and nodes in depth first order.
  152. for node in traverse_depth_first(root):
  153. if node.is_leaf:
  154. content[node] = str(node)
  155. else:
  156. arity = len(node)
  157. if is_operator(node):
  158. if arity == 1:
  159. content[node] = construct_unary(node)
  160. else:
  161. content[node] = construct_nary(node)
  162. else:
  163. result = None
  164. if hasattr(node, 'construct_function'):
  165. children = [content[c] for c in node]
  166. result = node.construct_function(children)
  167. if result == None:
  168. result = construct_function(node)
  169. content[node] = result
  170. # Add negations
  171. content[node] = '-' * node.negated + content[node]
  172. # Merge binary plus and unary minus signs into binary minus.
  173. return content[root].replace('+ -', '- ')