node.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import os.path
  2. import sys
  3. sys.path.insert(0, os.path.realpath('external'))
  4. from graph_drawing.graph import generate_graph
  5. from graph_drawing.line import generate_line
  6. from graph_drawing.node import Node, Leaf
  7. TYPE_OPERATOR = 1
  8. TYPE_IDENTIFIER = 2
  9. TYPE_INTEGER = 4
  10. TYPE_FLOAT = 8
  11. TYPE_NUMERIC = TYPE_INTEGER | TYPE_FLOAT
  12. # Unary
  13. OP_NEG = 1
  14. # Binary
  15. OP_ADD = 2
  16. OP_SUB = 3
  17. OP_MUL = 4
  18. OP_DIV = 5
  19. OP_POW = 6
  20. OP_MOD = 7
  21. # N-ary (functions)
  22. OP_INT = 8
  23. OP_EXPAND = 9
  24. TYPE_MAP = {
  25. int: TYPE_INTEGER,
  26. float: TYPE_FLOAT,
  27. str: TYPE_IDENTIFIER,
  28. }
  29. OP_MAP = {
  30. '+': OP_ADD,
  31. # Either substitution or negation. Skip the operator sign in 'x' (= 2).
  32. '-': lambda x: OP_SUB if len(x) > 2 else OP_NEG,
  33. '*': OP_MUL,
  34. '/': OP_DIV,
  35. '^': OP_POW,
  36. 'mod': OP_MOD,
  37. 'int': OP_INT,
  38. 'expand': OP_EXPAND,
  39. }
  40. class ExpressionNode(Node):
  41. def __init__(self, *args, **kwargs):
  42. super(ExpressionNode, self).__init__(*args, **kwargs)
  43. self.type = TYPE_OPERATOR
  44. self.op = OP_MAP[args[0]]
  45. if hasattr(self.op, '__call__'):
  46. self.op = self.op(args)
  47. def __str__(self): # pragma: nocover
  48. return generate_line(self)
  49. def graph(self): # pragma: nocover
  50. return generate_graph(self)
  51. def replace(self, node):
  52. pos = self.parent.nodes.index(self)
  53. self.parent.nodes[pos] = node
  54. node.parent = self.parent
  55. self.parent = None
  56. def is_power(self):
  57. return self.op == OP_POW
  58. def is_nary(self):
  59. return self.op in [OP_ADD, OP_SUB, OP_MUL]
  60. def get_order(self):
  61. if self.is_power() and self[0].is_identifier() \
  62. and isinstance(self[1], Leaf):
  63. return (self[0].value, self[1].value, 1)
  64. for n0, n1 in [(0, 1), (1, 0)]:
  65. if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \
  66. and self[n1].is_power():
  67. coeff, power = self
  68. if power[0].is_identifier() and isinstance(power[1], Leaf):
  69. return (power[0].value, power[1].value, coeff.value)
  70. def get_scope(self):
  71. scope = []
  72. for child in self:
  73. if not isinstance(child, Leaf) and child.op == self.op:
  74. scope += child.get_scope()
  75. else:
  76. scope.append(child)
  77. return scope
  78. class ExpressionLeaf(Leaf):
  79. def __init__(self, *args, **kwargs):
  80. super(ExpressionLeaf, self).__init__(*args, **kwargs)
  81. self.type = TYPE_MAP[type(args[0])]
  82. def get_order(self):
  83. if self.is_identifier():
  84. return (self.value, 1, 1)
  85. def replace(self, node):
  86. if not hasattr(self, 'parent'):
  87. return
  88. pos = self.parent.nodes.index(self)
  89. self.parent.nodes[pos] = node
  90. node.parent = self.parent
  91. self.parent = None
  92. def is_identifier(self):
  93. return self.type & TYPE_IDENTIFIER
  94. def is_int(self):
  95. return self.type & TYPE_INTEGER
  96. def is_float(self):
  97. return self.type & TYPE_FLOAT
  98. def is_numeric(self):
  99. return self.type & TYPE_NUMERIC
  100. if __name__ == '__main__': # pragma: nocover
  101. l0 = ExpressionLeaf(3)
  102. l1 = ExpressionLeaf(4)
  103. l2 = ExpressionLeaf(5)
  104. l3 = ExpressionLeaf(7)
  105. n0 = ExpressionNode('+', l0, l1)
  106. n1 = ExpressionNode('+', l2, l3)
  107. n2 = ExpressionNode('*', n0, n1)
  108. print n2
  109. N = ExpressionNode
  110. def rewrite_multiply(node):
  111. a, b = node[0]
  112. c, d = node[1]
  113. ac = N('*', a, c)
  114. ad = N('*', a, d)
  115. bc = N('*', b, c)
  116. bd = N('*', b, d)
  117. res = N('+', N('+', N('+', ac, ad), bc), bd)
  118. return res
  119. possibilities = [
  120. (n0, lambda (x, y): ExpressionLeaf(x.value + y.value)),
  121. (n1, lambda (x, y): ExpressionLeaf(x.value + y.value)),
  122. (n2, rewrite_multiply),
  123. ]
  124. print '\n--- after rule 2 ---\n'
  125. n_, method = possibilities[2]
  126. new = method(n_)
  127. print new
  128. print '\n--- original graph ---\n'
  129. print n2
  130. print '\n--- apply rule 0 ---\n'
  131. n_, method = possibilities[0]
  132. new = method(n_)
  133. n_.replace(new)
  134. print n2
  135. # Revert rule 0
  136. new.replace(n_)
  137. print '\n--- apply rule 1 ---\n'
  138. n_, method = possibilities[1]
  139. new = method(n_)
  140. n_.replace(new)
  141. print n2