line.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from traverse import traverse_depth_first
  2. OPERATORS = [
  3. ('+', '-'),
  4. ('*', '/', 'mod'),
  5. ('^', )
  6. ]
  7. MAX_PRED = len(OPERATORS)
  8. def is_operator(node):
  9. """
  10. Check if a given node is an operator (otherwise, it's a function).
  11. """
  12. label = node.title()
  13. return any(map(lambda x: label in x, OPERATORS))
  14. def pred(node):
  15. """
  16. Get the precedence of an operator node.
  17. """
  18. # Check binary and n-ary operators
  19. if not node.is_leaf and len(node) > 1:
  20. op = node.title()
  21. for i, group in enumerate(OPERATORS):
  22. if op in group:
  23. return i
  24. # Unary operator and leaves have highest precedence
  25. return MAX_PRED
  26. def is_id(node):
  27. return node.is_leaf and not node.title().isdigit()
  28. def is_int(node):
  29. return node.is_leaf and node.title().isdigit()
  30. def generate_line(root):
  31. """
  32. Print an expression tree in a single text line. Where needed, add
  33. parentheses.
  34. >>> from node import Node, Leaf
  35. >>> l0, l1 = Leaf(1), Leaf(2)
  36. >>> print generate_line(l0)
  37. 1
  38. >>> plus = Node('+', l0, l1)
  39. >>> print generate_line(plus)
  40. 1 + 2
  41. >>> plus2 = Node('+', l0, l1)
  42. >>> times = Node('*', plus, plus2)
  43. >>> print generate_line(times)
  44. (1 + 2)(1 + 2)
  45. >>> l2 = Leaf(3)
  46. >>> uminus = Node('-', l2)
  47. >>> times = Node('*', plus, uminus)
  48. >>> print generate_line(times)
  49. (1 + 2) * -3
  50. >>> exp = Leaf('x')
  51. >>> inf = Leaf('oo')
  52. >>> minus_inf = Node('-', inf)
  53. >>> integral = Node('int', exp, minus_inf, inf)
  54. >>> print generate_line(integral)
  55. int(x, -oo, oo)
  56. >>> minus = Node('-', Leaf(2), Node('-', Node('*', Leaf(15), Leaf('x'))))
  57. >>> print generate_line(minus)
  58. 2 - -15x
  59. >>> left = Node('/', Leaf(22), Leaf(77))
  60. >>> right = Node('/', Leaf(28), Leaf(77))
  61. >>> minus = Node('-', left, right)
  62. >>> print generate_line(minus)
  63. 22 / 77 - 28 / 77
  64. >>> plus = Node('+', left, Node('-', right))
  65. >>> print generate_line(plus)
  66. 22 / 77 - 28 / 77
  67. """
  68. if not root:
  69. return '<empty expression>'
  70. if root.is_leaf:
  71. return '-' * root.negated + root.title()
  72. content = {}
  73. def construct_unary(node):
  74. op = node.title()
  75. value = node[0]
  76. # -a
  77. # -3 * 4
  78. # --a
  79. if value.is_leaf \
  80. or not ' ' in content[value] or pred(value) > 0:
  81. return op + content[value]
  82. # -(a + b)
  83. return '%s(%s)' % (op, content[value])
  84. def construct_nary(node):
  85. op = node.title()
  86. # N-ary operator
  87. node_pred = pred(node)
  88. sep = ' ' + op + ' '
  89. e = []
  90. for i, child in enumerate(node):
  91. exp = content[child]
  92. #if i and op == '+' and exp[:2] == '-(':
  93. # exp = '-' + exp[2:-1]
  94. # print 'exp:', exp
  95. # Check if there is a precedence conflict
  96. # If so, add parentheses
  97. child_pred = pred(child)
  98. if not child.negated and (child_pred < node_pred \
  99. or (i and child_pred == node_pred \
  100. and (op != child.title() \
  101. or (op == '+' and child[1].negated)))):
  102. exp = '(' + exp + ')'
  103. e.append(exp)
  104. if op == '*':
  105. # Check if an explicit multiplication sign is nessecary
  106. left, right = node
  107. # Get the previous multiplication element if the arity is
  108. # greater than 2
  109. if left.title() == '*':
  110. left = left[1]
  111. # a * b -> ab
  112. # a * 2 -> a * 2
  113. # a * (b) -> a(b)
  114. # (a) * b -> (a)b
  115. # (a) * (b) -> (a)(b)
  116. # 2 * a -> 2a
  117. left_paren = e[0][-1] == ')'
  118. right_paren = e[1][0] == '('
  119. if (is_id(left) or left_paren or is_int(left)) \
  120. and ((not right.negated and is_id(right)) or right_paren):
  121. sep = ''
  122. exp = sep.join(e)
  123. if node.negated and op not in ('*', '/'):
  124. exp = '(' + exp + ')'
  125. return exp
  126. def construct_function(node):
  127. buf = []
  128. for child in node:
  129. buf.append(content[child])
  130. return '%s(%s)' % (node.title(), ', '.join(buf))
  131. # Traverse the expression tree and construct the mathematical expression in
  132. # the leafs and nodes in depth first order.
  133. for node in traverse_depth_first(root):
  134. if node.is_leaf:
  135. content[node] = node.title()
  136. else:
  137. arity = len(node)
  138. if is_operator(node):
  139. if arity == 1:
  140. content[node] = construct_unary(node)
  141. else:
  142. content[node] = construct_nary(node)
  143. else:
  144. content[node] = construct_function(node)
  145. # Add negations
  146. content[node] = '-' * node.negated + content[node]
  147. # Merge binary plus and unary minus signs into binary minus.
  148. return content[root].replace('+ -', '- ')