line.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from node import Node
  2. def generate_line(root):
  3. """
  4. Print an expression tree in a single text line. Where needed, add
  5. parentheses.
  6. >>> from node import Leaf
  7. >>> l0, l1 = Leaf(1), Leaf(2)
  8. >>> plus = Node('+', l0, l1)
  9. >>> print generate_line(plus)
  10. 1 + 2
  11. >>> plus2 = Node('+', l0, l1)
  12. >>> times = Node('*', plus, plus2)
  13. >>> print generate_line(times)
  14. (1 + 2) * (1 + 2)
  15. >>> l2 = Leaf(3)
  16. >>> uminus = Node('-', l2)
  17. >>> times = Node('*', plus, uminus)
  18. >>> print generate_line(times)
  19. (1 + 2) * -3
  20. >>> exp = Leaf('x')
  21. >>> inf = Leaf('oo')
  22. >>> minus_inf = Node('-', inf)
  23. >>> integral = Node('int', exp, minus_inf, inf)
  24. >>> print generate_line(integral)
  25. int(x, -oo, oo)
  26. """
  27. # FIXME: Binding order
  28. operators = [
  29. ('+', '-'),
  30. ('mod', ),
  31. ('*', '/'),
  32. ('^', )
  33. ]
  34. max_assoc = len(operators)
  35. def is_operator(node):
  36. """
  37. Check if a given node is an operator (otherwise, it's a function).
  38. """
  39. label = node.title()
  40. either = lambda a, b: a or b
  41. return reduce(either, map(lambda x: label in x, operators))
  42. def assoc(node):
  43. """
  44. Get the associativity of an operator node.
  45. """
  46. if isinstance(node, Node) and len(node) > 1:
  47. op = node.title()
  48. for i, group in enumerate(operators):
  49. if op in group:
  50. return i
  51. return max_assoc
  52. def traverse(node):
  53. """
  54. The expression tree is traversed using preorder traversal:
  55. 1. Visit the root
  56. 2. Traverse the subtrees in left-to-right order
  57. """
  58. s = node.title()
  59. if not isinstance(node, Node):
  60. return s
  61. arity = len(node)
  62. if is_operator(node):
  63. if arity == 1:
  64. # Unary expression
  65. s += traverse(node[0])
  66. elif arity == 2:
  67. # Binary expression
  68. left, right = map(traverse, node)
  69. # Check if there is an assiociativity conflict on either side.
  70. # If so, add parentheses
  71. node_assoc = assoc(node)
  72. if assoc(node[0]) < node_assoc:
  73. left = '(' + left + ')'
  74. if assoc(node[1]) < node_assoc:
  75. right = '(' + right + ')'
  76. s = left + ' ' + s + ' ' + right
  77. else: # pragma: nocover
  78. raise ValueError('arity = %d is currently not supported.' \
  79. % arity)
  80. else:
  81. # Function
  82. s += '(' + ', '.join(map(traverse, node)) + ')'
  83. return s
  84. return traverse(root)