node.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # vim: set fileencoding=utf-8 :
  2. import os.path
  3. import sys
  4. import copy
  5. sys.path.insert(0, os.path.realpath('external'))
  6. from graph_drawing.graph import generate_graph
  7. from graph_drawing.line import generate_line
  8. from graph_drawing.node import Node, Leaf
  9. TYPE_OPERATOR = 1
  10. TYPE_IDENTIFIER = 2
  11. TYPE_INTEGER = 4
  12. TYPE_FLOAT = 8
  13. # Unary
  14. OP_NEG = 1
  15. # Binary
  16. OP_ADD = 2
  17. OP_SUB = 3
  18. OP_MUL = 4
  19. OP_DIV = 5
  20. OP_POW = 6
  21. OP_MOD = 7
  22. # N-ary (functions)
  23. OP_INT = 8
  24. OP_EXPAND = 9
  25. OP_COMMA = 10
  26. OP_SQRT = 11
  27. TYPE_MAP = {
  28. int: TYPE_INTEGER,
  29. float: TYPE_FLOAT,
  30. str: TYPE_IDENTIFIER,
  31. }
  32. OP_MAP = {
  33. '+': OP_ADD,
  34. # Either substraction or negation. Skip the operator sign in 'x' (= 2).
  35. '-': lambda x: OP_SUB if len(x) > 2 else OP_NEG,
  36. '*': OP_MUL,
  37. '/': OP_DIV,
  38. '^': OP_POW,
  39. 'mod': OP_MOD,
  40. 'int': OP_INT,
  41. 'expand': OP_EXPAND,
  42. 'sqrt': OP_SQRT,
  43. ',': OP_COMMA,
  44. }
  45. def to_expression(obj):
  46. return obj if isinstance(obj, ExpressionBase) else ExpressionLeaf(obj)
  47. class ExpressionBase(object):
  48. def __init__(self, *args, **kwargs):
  49. self.negated = 0
  50. def clone(self):
  51. return copy.deepcopy(self)
  52. def __lt__(self, other):
  53. """
  54. Comparison between this expression{node,leaf} and another
  55. expression{node,leaf}. This comparison will return True if this
  56. instance has less value than the other expression{node,leaf}.
  57. Otherwise, False is returned.
  58. The comparison is based on the following conditions:
  59. 1. Both are leafs. String comparison of the value is used.
  60. 2. This is a leaf and other is a node. This leaf has less value, thus
  61. True is returned.
  62. 3. This is a node and other is a leaf. This leaf has more value, thus
  63. False is returned.
  64. 4. Both are nodes. Compare the polynome properties of the nodes. True
  65. is returned if this node's root property is less than other's root
  66. property, or this node's exponent property is less than other's
  67. exponent property, or this node's coefficient property is less than
  68. other's coefficient property. Otherwise, False is returned.
  69. """
  70. if self.is_leaf:
  71. if other.is_leaf:
  72. # Both are leafs, string compare the value.
  73. return str(self.value) < str(other.value)
  74. # Self is a leaf, thus has less value than an expression node.
  75. return True
  76. if self.is_op(OP_NEG) and self[0].is_leaf:
  77. if other.is_leaf:
  78. # Both are leafs, string compare the value.
  79. return ('-' + str(self.value)) < str(other.value)
  80. if other.is_op(OP_NEG) and other[0].is_leaf:
  81. return ('-' + str(self.value)) < ('-' + str(other.value))
  82. # Self is a leaf, thus has less value than an expression node.
  83. return True
  84. if other.is_leaf:
  85. # Self is an expression node, and the other is a leaf. Thus, other
  86. # is greater than self.
  87. return False
  88. # Both are nodes, compare the polynome properties.
  89. s_coeff, s_root, s_exp = self.extract_polynome_properties()
  90. o_coeff, o_root, o_exp = other.extract_polynome_properties()
  91. return s_root < o_root or s_exp < o_exp or s_coeff < o_coeff
  92. def is_op(self, op):
  93. return not self.is_leaf and self.op == op
  94. def is_op_or_negated(self, op):
  95. if self.is_leaf:
  96. return False
  97. if self.op == OP_NEG:
  98. return self[0].is_op(op)
  99. return self.op == op
  100. def is_leaf_or_negated(self):
  101. if self.is_leaf:
  102. return True
  103. if self.is_op(OP_NEG):
  104. return self[0].is_leaf
  105. return False
  106. def is_power(self):
  107. return not self.is_leaf and self.op == OP_POW
  108. def is_nary(self):
  109. return not self.is_leaf and self.op in [OP_ADD, OP_SUB, OP_MUL]
  110. def is_identifier(self):
  111. return self.type == TYPE_IDENTIFIER
  112. def is_int(self):
  113. return self.type == TYPE_INTEGER
  114. def is_float(self):
  115. return self.type == TYPE_FLOAT
  116. def is_numeric(self):
  117. return self.type & (TYPE_FLOAT | TYPE_INTEGER)
  118. def __add__(self, other):
  119. return ExpressionNode('+', self, to_expression(other))
  120. def __sub__(self, other):
  121. return ExpressionNode('-', self, to_expression(other))
  122. def __mul__(self, other):
  123. return ExpressionNode('*', self, to_expression(other))
  124. def __div__(self, other):
  125. return ExpressionNode('/', self, to_expression(other))
  126. def __pow__(self, other):
  127. return ExpressionNode('^', self, to_expression(other))
  128. def __neg__(self):
  129. self.negated += 1
  130. return self
  131. class ExpressionNode(Node, ExpressionBase):
  132. def __init__(self, *args, **kwargs):
  133. super(ExpressionNode, self).__init__(*args, **kwargs)
  134. self.type = TYPE_OPERATOR
  135. self.op = OP_MAP[args[0]]
  136. if hasattr(self.op, '__call__'):
  137. self.op = self.op(args)
  138. def __str__(self): # pragma: nocover
  139. return generate_line(self)
  140. def __eq__(self, other):
  141. """
  142. Check strict equivalence.
  143. """
  144. if isinstance(other, ExpressionNode):
  145. return self.op == other.op and self.nodes == other.nodes
  146. return False
  147. def substitute(self, old_child, new_child):
  148. self.nodes[self.nodes.index(old_child)] = new_child
  149. def graph(self): # pragma: nocover
  150. return generate_graph(self)
  151. def extract_polynome_properties(self):
  152. """
  153. Extract polynome properties into tuple format: (coefficient, root,
  154. exponent). Thus: c * r ^ e will be extracted into the tuple (c, r, e).
  155. This function will normalize the expression before extracting the
  156. properties. Therefore, the expression r ^ e * c results the same tuple
  157. (c, r, e) as the expression c * r ^ e.
  158. >>> from src.node import ExpressionNode as N, ExpressionLeaf as L
  159. >>> c, r, e = L('c'), L('r'), L('e')
  160. >>> n1 = N('*', c, N('^', r, e))
  161. >>> n1.extract_polynome()
  162. (c, r, e)
  163. >>> n2 = N('*', N('^', r, e), c)
  164. >>> n2.extract_polynome()
  165. (c, r, e)
  166. >>> n3 = N('-', r)
  167. >>> n3.extract_polynome()
  168. (1, -r, 1)
  169. """
  170. # TODO: change "get_polynome" -> "extract_polynome".
  171. # TODO: change retval of c * r ^ e to (c, r, e).
  172. # was: (root, exponent, coefficient, literal_exponent)
  173. # rule: r ^ e -> (1, r, e)
  174. if self.is_power():
  175. return (ExpressionLeaf(1), self[0], self[1])
  176. # rule: -r -> (1, r, 1)
  177. if self.is_op(OP_NEG):
  178. return (ExpressionLeaf(1), -self[0], ExpressionLeaf(1))
  179. if self.op != OP_MUL:
  180. return
  181. # rule: 3 * 7 ^ e | 'a' * 'b' ^ e
  182. # expression: c * r ^ e ; tree:
  183. #
  184. # *
  185. # ╭┴───╮
  186. # c ^
  187. # ╭─┴╮
  188. # r e
  189. #
  190. # rule: c * r ^ e | (r ^ e) * c
  191. for i, j in ((0, 1), (1, 0)):
  192. if self[j].is_power():
  193. return (self[i], self[j][0], self[j][1])
  194. # Normalize c * r and r * c -> c * r. Otherwise, the tuple will not
  195. # match if the order of the expression is different. Example:
  196. # r ^ e * c == c * r ^ e
  197. # without normalization, those expressions will not match.
  198. #
  199. # rule: c * r | r * c
  200. if self[0] < self[1]:
  201. return (self[0], self[1], ExpressionLeaf(1))
  202. return (self[1], self[0], ExpressionLeaf(1))
  203. def equals(self, other):
  204. """
  205. Perform a non-strict equivalence check between two nodes:
  206. - If the other node is a leaf, it cannot be equal to this node.
  207. - If their operators differ, the nodes are not equal.
  208. - If both nodes are additions or both are multiplications, match each
  209. node in one scope to one in the other (an injective relationship).
  210. Any difference in order of the scopes is irrelevant.
  211. - If both nodes are divisions, the nominator and denominator have to be
  212. non-strictly equal.
  213. """
  214. if not other.is_op(self.op):
  215. # FIXME: this is if-clause is a problem. To fix this problem
  216. # permanently, normalize ("x * -1" -> "-1x") before comparing to
  217. # the other node.
  218. return False
  219. if self.op in (OP_ADD, OP_MUL):
  220. s0 = Scope(self)
  221. s1 = set(Scope(other))
  222. # Scopes sould be of equal size
  223. if len(s0) != len(s1):
  224. return False
  225. # Each node in one scope should have an image node in the other
  226. matched = set()
  227. for n0 in s0:
  228. found = False
  229. for n1 in s1 - matched:
  230. if n0.equals(n1):
  231. found = True
  232. matched.add(n1)
  233. break
  234. if not found:
  235. return False
  236. else:
  237. # Check if all children are non-strictly equal, preserving order
  238. for i, child in enumerate(self):
  239. if not child.equals(other[i]):
  240. return False
  241. return True
  242. class ExpressionLeaf(Leaf, ExpressionBase):
  243. def __init__(self, *args, **kwargs):
  244. super(ExpressionLeaf, self).__init__(*args, **kwargs)
  245. self.type = TYPE_MAP[type(args[0])]
  246. def __eq__(self, other):
  247. """
  248. Check strict equivalence.
  249. """
  250. other_type = type(other)
  251. if other_type in TYPE_MAP:
  252. return TYPE_MAP[other_type] == self.type and self.value == other
  253. return other.type == self.type and self.value == other.value
  254. def equals(self, other):
  255. """
  256. Check non-strict equivalence.
  257. Between leaves, this is the same as strict equivalence.
  258. """
  259. return self == other
  260. def extract_polynome_properties(self):
  261. """
  262. An expression leaf will return the polynome tuple (1, r, 1), where r is
  263. the leaf itself. See also the method extract_polynome_properties in
  264. ExpressionBase.
  265. """
  266. # rule: 1 * r ^ 1 -> (1, r, 1)
  267. return (ExpressionLeaf(1), self, ExpressionLeaf(1))
  268. class Scope(object):
  269. def __init__(self, node):
  270. self.node = node
  271. self.nodes = get_scope(node)
  272. def __getitem__(self, key):
  273. return self.nodes[key]
  274. def __setitem__(self, key, value):
  275. self.nodes[key] = value
  276. def __len__(self):
  277. return len(self.nodes)
  278. def __iter__(self):
  279. return iter(self.nodes)
  280. def remove(self, node, replacement=None):
  281. if node.is_leaf:
  282. node_cmp = hash(node)
  283. else:
  284. node_cmp = node
  285. for i, n in enumerate(self.nodes):
  286. if n.is_leaf:
  287. n_cmp = hash(n)
  288. else:
  289. n_cmp = n
  290. if n_cmp == node_cmp:
  291. if replacement != None:
  292. self[i] = replacement
  293. else:
  294. del self.nodes[i]
  295. return
  296. raise ValueError('Node "%s" is not in the scope of "%s".'
  297. % (node, self.node))
  298. def as_nary_node(self):
  299. return nary_node(self.node.value, self.nodes)
  300. def nary_node(operator, scope):
  301. """
  302. Create a binary expression tree for an n-ary operator. Takes the operator
  303. and a list of expression nodes as arguments.
  304. """
  305. if len(scope) == 1:
  306. return scope[0]
  307. return ExpressionNode(operator, nary_node(operator, scope[:-1]), scope[-1])
  308. def get_scope(node):
  309. """
  310. Find all n nodes within the n-ary scope of an operator node.
  311. """
  312. scope = []
  313. for child in node:
  314. if child.is_op(node.op):
  315. scope += get_scope(child)
  316. else:
  317. scope.append(child)
  318. return scope