line.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from node import Leaf
  2. def generate_line(root):
  3. """
  4. Print an expression tree in a single text line. Where needed, add
  5. parentheses.
  6. >>> from node import Node, Leaf
  7. >>> l0, l1 = Leaf(1), Leaf(2)
  8. >>> plus = Node('+', l0, l1)
  9. >>> print generate_line(plus)
  10. 1 + 2
  11. >>> plus2 = Node('+', l0, l1)
  12. >>> times = Node('*', plus, plus2)
  13. >>> print generate_line(times)
  14. (1 + 2)(1 + 2)
  15. >>> l2 = Leaf(3)
  16. >>> uminus = Node('-', l2)
  17. >>> times = Node('*', plus, uminus)
  18. >>> print generate_line(times)
  19. (1 + 2) * -3
  20. >>> exp = Leaf('x')
  21. >>> inf = Leaf('oo')
  22. >>> minus_inf = Node('-', inf)
  23. >>> integral = Node('int', exp, minus_inf, inf)
  24. >>> print generate_line(integral)
  25. int(x, -oo, oo)
  26. """
  27. operators = [
  28. ('+', '-'),
  29. ('*', '/', 'mod'),
  30. ('^', )
  31. ]
  32. max_pred = len(operators)
  33. def is_operator(node):
  34. """
  35. Check if a given node is an operator (otherwise, it's a function).
  36. """
  37. label = node.title()
  38. either = lambda a, b: a or b
  39. return reduce(either, map(lambda x: label in x, operators))
  40. def pred(node):
  41. """
  42. Get the precedence of an operator node.
  43. """
  44. # Check binary and n-ary operators
  45. if not isinstance(node, Leaf) and len(node) > 1:
  46. op = node.title()
  47. for i, group in enumerate(operators):
  48. if op in group:
  49. return i
  50. # Unary operator and leaves have highest precedence
  51. return max_pred
  52. def traverse(node):
  53. """
  54. The expression tree is traversed using preorder traversal:
  55. 1. Visit the root
  56. 2. Traverse the subtrees in left-to-right order
  57. """
  58. if not node:
  59. return '<empty expression>'
  60. op = node.title()
  61. if not node.nodes:
  62. return op
  63. arity = len(node)
  64. if is_operator(node):
  65. if arity == 1:
  66. # Unary operator
  67. sub = node[0]
  68. sub_exp = traverse(sub)
  69. # Negated sub-expressions with spaces in them should be
  70. # enclosed in parentheses, unless they have a higher precedence
  71. # than subtraction are rewritten to a factor of a subtraction:
  72. # -(1 + 2)
  73. # -(1 - 2)
  74. # -4a
  75. # -(4 * 5)
  76. # 1 - 4 * 5
  77. # 1 + -(4 * 5) -> 1 - 4 * 5
  78. if ' ' in sub_exp and not (not isinstance(sub, Leaf) \
  79. and hasattr(node, 'marked_negation')
  80. and pred(sub) > 0):
  81. sub_exp = '(' + sub_exp + ')'
  82. result = op + sub_exp
  83. else:
  84. # N-ary operator
  85. node_pred = pred(node)
  86. result = ''
  87. sep = ' ' + op + ' '
  88. e = []
  89. # Mark added and subtracted negations for later use when adding
  90. # parentheses
  91. if op in ('+', '-'):
  92. for child in node:
  93. if child.title() == '-' and len(child) == 1:
  94. child.marked_negation = True
  95. for i, child in enumerate(node):
  96. exp = traverse(child)
  97. # Check if there is a precedence conflict
  98. # If so, add parentheses
  99. child_pred = pred(child)
  100. if child_pred < node_pred or \
  101. (i and child_pred == node_pred \
  102. and op != child.title()):
  103. exp = '(' + exp + ')'
  104. e.append(exp)
  105. if op == '*':
  106. # Check if an explicit multiplication sign is nessecary
  107. left, right = node
  108. # Get the previous multiplication element if the arity is
  109. # greater than 2
  110. if left.title() == '*':
  111. left = left[1]
  112. # a * b -> ab
  113. # a * 2 -> a * 2
  114. # a * (b) -> a(b)
  115. # (a) * b -> (a)b
  116. # (a) * (b) -> (a)(b)
  117. # 2 * a -> 2a
  118. left_id = is_id(left)
  119. right_id = is_id(right)
  120. left_paren = e[0][-1] == ')'
  121. right_paren = e[1][0] == '('
  122. left_int = is_int(left)
  123. if (left_id or left_paren or left_int) \
  124. and (right_id or right_paren):
  125. sep = ''
  126. result += sep.join(e)
  127. else:
  128. # Function call
  129. result = op + '(' + ', '.join(map(traverse, node)) + ')'
  130. # An addition with negation can be written as a subtraction, e.g.:
  131. # 1 + -2 -> 1 - 2
  132. return result.replace('+ -', '- ')
  133. return traverse(root)
  134. def is_negation(node):
  135. if isinstance(node, Leaf):
  136. return False
  137. return node.title() == '-' and len(node) == 1
  138. def is_id(node):
  139. if is_negation(node):
  140. return is_id(node[0])
  141. return isinstance(node, Leaf) and not node.title().isdigit()
  142. def is_int(node):
  143. if is_negation(node):
  144. return is_int(node[0])
  145. return isinstance(node, Leaf) and node.title().isdigit()