line.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from traverse import traverse_depth_first
  2. OPERATORS = (
  3. ('vv', ),
  4. ('^^', ),
  5. ('=', ),
  6. ('+', '-'),
  7. ('*', '/', 'mod'),
  8. ('^', '_'),
  9. )
  10. NEG_PRED = 3
  11. MAX_PRED = len(OPERATORS)
  12. def is_operator(node):
  13. """
  14. Check if a given node is an operator (otherwise, it's a function).
  15. """
  16. label = node.title()
  17. return any(map(lambda x: label in x, OPERATORS))
  18. def pred(node):
  19. """
  20. Get the precedence of an operator node.
  21. """
  22. # Check binary and n-ary operators
  23. if not node.is_leaf and len(node) > 1:
  24. op = node.title()
  25. for i, group in enumerate(OPERATORS):
  26. if op in group:
  27. return i
  28. # Unary operator and leaves have highest precedence
  29. return MAX_PRED
  30. def is_id(node):
  31. return node.is_leaf and not node.title().isdigit()
  32. def is_int(node):
  33. return node.is_leaf and node.title().isdigit()
  34. def is_power(node):
  35. return not node.is_leaf and node.title() == '^'
  36. def generate_line(root):
  37. """
  38. Print an expression tree in a single text line. Where needed, add
  39. parentheses.
  40. >>> from node import Node, Leaf
  41. >>> l0, l1 = Leaf(1), Leaf(2)
  42. >>> print generate_line(l0)
  43. 1
  44. >>> plus = Node('+', l0, l1)
  45. >>> print generate_line(plus)
  46. 1 + 2
  47. >>> plus2 = Node('+', l0, l1)
  48. >>> times = Node('*', plus, plus2)
  49. >>> print generate_line(times)
  50. (1 + 2)(1 + 2)
  51. >>> l2 = Leaf(3)
  52. >>> uminus = Node('-', l2)
  53. >>> times = Node('*', plus, uminus)
  54. >>> print generate_line(times)
  55. (1 + 2) * -3
  56. >>> exp = Leaf('x')
  57. >>> inf = Leaf('oo')
  58. >>> minus_inf = Node('-', inf)
  59. >>> integral = Node('int', exp, minus_inf, inf)
  60. >>> print generate_line(integral)
  61. int(x, -oo, oo)
  62. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  63. >>> print generate_line(minus)
  64. 2 - -15x
  65. >>> left = Node('/', Leaf(22), Leaf(77))
  66. >>> right = Node('/', Leaf(28), Leaf(77))
  67. >>> minus = Node('-', left, right)
  68. >>> print generate_line(minus)
  69. 22 / 77 - 28 / 77
  70. >>> plus = Node('+', left, Node('-', right))
  71. >>> print generate_line(plus)
  72. 22 / 77 - 28 / 77
  73. """
  74. if not root:
  75. return '<empty expression>'
  76. if root.is_leaf:
  77. return str(root)
  78. content = {}
  79. def construct_unary(node):
  80. op = node.title()
  81. value = node[0]
  82. # -a
  83. # -3 * 4
  84. # --a
  85. if value.is_leaf \
  86. or not ' ' in content[value] or pred(value) > NEG_PRED:
  87. return op + content[value]
  88. # -(a + b)
  89. return '%s(%s)' % (op, content[value])
  90. def construct_nary(node):
  91. op = node.title()
  92. # N-ary operator
  93. node_pred = pred(node)
  94. sep = ' ' + op + ' '
  95. e = []
  96. for i, child in enumerate(node):
  97. exp = content[child]
  98. #if i and op == '+' and exp[:2] == '-(':
  99. # exp = '-' + exp[2:-1]
  100. # print 'exp:', exp
  101. # Check if there is a precedence conflict
  102. # If so, add parentheses
  103. child_pred = pred(child)
  104. if child.negated:
  105. # (-a) ^ b
  106. # -a ^ -b
  107. # (-a) * b
  108. # a * -b
  109. # (-a) / b
  110. if node_pred > NEG_PRED:
  111. exp = '(' + exp + ')'
  112. elif child_pred < node_pred:
  113. exp = '(' + exp + ')'
  114. elif child_pred == node_pred:
  115. if i and (op != child.title() or op == '/' \
  116. or (op == '+' and child[1].negated)):
  117. exp = '(' + exp + ')'
  118. elif not i and op == '^':
  119. exp = '(' + exp + ')'
  120. e.append(exp)
  121. if op == '*':
  122. # Check if an explicit multiplication sign is nessecary
  123. left, right = node
  124. # Get the previous multiplication element if the arity is
  125. # greater than 2
  126. if left.title() == '*':
  127. left = left[1]
  128. # a * b -> ab
  129. # a * 2 -> a * 2
  130. # a * (b) -> a(b)
  131. # (a) * b -> (a)b
  132. # (a) * (b) -> (a)(b)
  133. # 2 * a -> 2a
  134. l = e[0][-1]
  135. r = e[1][0]
  136. left_simple = is_id(left) or is_int(left)
  137. if (r in ('(', '[') and left_simple) or (l == ')' and r != '-') \
  138. or (left_simple and r.isalpha()):
  139. sep = ''
  140. exp = sep.join(e)
  141. if node.negated and op not in ('*', '/', '^'):
  142. exp = '(' + exp + ')'
  143. return exp
  144. def construct_function(node):
  145. buf = []
  146. for child in node:
  147. buf.append(content[child])
  148. return '%s(%s)' % (node.title(), ', '.join(buf))
  149. # Traverse the expression tree and construct the mathematical expression in
  150. # the leafs and nodes in depth first order.
  151. for node in traverse_depth_first(root):
  152. if node.is_leaf:
  153. content[node] = str(node)
  154. else:
  155. arity = len(node)
  156. if is_operator(node):
  157. if arity == 1:
  158. content[node] = construct_unary(node)
  159. else:
  160. content[node] = construct_nary(node)
  161. else:
  162. result = None
  163. if hasattr(node, 'construct_function'):
  164. children = [content[c] for c in node]
  165. result = node.construct_function(children)
  166. if result == None:
  167. result = construct_function(node)
  168. content[node] = result
  169. # Add negations
  170. content[node] = '-' * node.negated + content[node]
  171. # Merge binary plus and unary minus signs into binary minus.
  172. return content[root].replace('+ -', '- ')