line.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. def generate_line(root):
  2. """
  3. Print an expression tree in a single text line. Where needed, add
  4. parentheses.
  5. >>> from node import Node, Leaf
  6. >>> l0, l1 = Leaf(1), Leaf(2)
  7. >>> plus = Node('+', l0, l1)
  8. >>> print generate_line(plus)
  9. 1 + 2
  10. >>> plus2 = Node('+', l0, l1)
  11. >>> times = Node('*', plus, plus2)
  12. >>> print generate_line(times)
  13. (1 + 2) * (1 + 2)
  14. >>> l2 = Leaf(3)
  15. >>> uminus = Node('-', l2)
  16. >>> times = Node('*', plus, uminus)
  17. >>> print generate_line(times)
  18. (1 + 2) * -3
  19. >>> exp = Leaf('x')
  20. >>> inf = Leaf('oo')
  21. >>> minus_inf = Node('-', inf)
  22. >>> integral = Node('int', exp, minus_inf, inf)
  23. >>> print generate_line(integral)
  24. int(x, -oo, oo)
  25. """
  26. operators = [
  27. ('+', '-'),
  28. ('*', '/', 'mod'),
  29. ('^', )
  30. ]
  31. max_assoc = len(operators)
  32. def is_operator(node):
  33. """
  34. Check if a given node is an operator (otherwise, it's a function).
  35. """
  36. label = node.title()
  37. either = lambda a, b: a or b
  38. return reduce(either, map(lambda x: label in x, operators))
  39. def assoc(node):
  40. """
  41. Get the associativity of an operator node.
  42. """
  43. if not node.nodes and len(node) > 1:
  44. op = node.title()
  45. for i, group in enumerate(operators):
  46. if op in group:
  47. return i
  48. return max_assoc
  49. def traverse(node):
  50. """
  51. The expression tree is traversed using preorder traversal:
  52. 1. Visit the root
  53. 2. Traverse the subtrees in left-to-right order
  54. """
  55. if not node:
  56. return '<empty expression>'
  57. s = node.title()
  58. if not node.nodes:
  59. return s
  60. arity = len(node)
  61. if is_operator(node):
  62. if arity == 1:
  63. # Unary operator
  64. s += traverse(node[0])
  65. else:
  66. # N-ary operator
  67. node_assoc = assoc(node)
  68. e = []
  69. for child in node:
  70. exp = traverse(child)
  71. # Check if there is an assiociativity conflict.
  72. # If so, add parentheses
  73. if assoc(child) < node_assoc:
  74. exp = '(' + exp + ')'
  75. e.append(exp)
  76. s = (' ' + s + ' ').join(e)
  77. else:
  78. # Function
  79. s += '(' + ', '.join(map(traverse, node)) + ')'
  80. return s
  81. return traverse(root)