|
|
@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
|
|
|
return (self[0], self[1], ExpressionLeaf(1))
|
|
|
return (self[1], self[0], ExpressionLeaf(1))
|
|
|
|
|
|
- def get_scope(self):
|
|
|
- """
|
|
|
- Find all n nodes within the n-ary scope of this operator.
|
|
|
- """
|
|
|
- scope = []
|
|
|
- #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
|
|
|
-
|
|
|
- # TODO: what to do with OP_SUB and OP_ADD in get_scope?
|
|
|
-
|
|
|
- for child in self:
|
|
|
- if not child.is_leaf() and child.op == self.op:
|
|
|
- scope += child.get_scope()
|
|
|
- else:
|
|
|
- scope.append(child)
|
|
|
-
|
|
|
- return scope
|
|
|
-
|
|
|
def equals(self, other):
|
|
|
"""
|
|
|
Perform a non-strict equivalence check between two nodes:
|
|
|
@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
|
|
|
return False
|
|
|
|
|
|
if self.op in (OP_ADD, OP_MUL):
|
|
|
- s0 = self.get_scope()
|
|
|
- s1 = set(other.get_scope())
|
|
|
+ s0 = Scope(self)
|
|
|
+ s1 = set(Scope(other))
|
|
|
|
|
|
# Scopes sould be of equal size
|
|
|
if len(s0) != len(s1):
|
|
|
@@ -354,3 +337,72 @@ class ExpressionLeaf(Leaf, ExpressionBase):
|
|
|
"""
|
|
|
# rule: 1 * r ^ 1 -> (1, r, 1)
|
|
|
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
|
|
|
+
|
|
|
+
|
|
|
+class Scope(object):
|
|
|
+
|
|
|
+ def __init__(self, node):
|
|
|
+ self.node = node
|
|
|
+ self.nodes = get_scope(node)
|
|
|
+
|
|
|
+ def __getitem__(self, key):
|
|
|
+ return self.nodes[key]
|
|
|
+
|
|
|
+ def __setitem__(self, key, value):
|
|
|
+ self.nodes[key] = value
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.nodes)
|
|
|
+
|
|
|
+ def __iter__(self):
|
|
|
+ return iter(self.nodes)
|
|
|
+
|
|
|
+ def remove(self, node, replacement=None):
|
|
|
+ if node.is_leaf():
|
|
|
+ node_cmp = hash(node)
|
|
|
+ else:
|
|
|
+ node_cmp = node
|
|
|
+
|
|
|
+ for i, n in enumerate(self.nodes):
|
|
|
+ if n.is_leaf():
|
|
|
+ n_cmp = hash(n)
|
|
|
+ else:
|
|
|
+ n_cmp = n
|
|
|
+
|
|
|
+ if n_cmp == node_cmp:
|
|
|
+ if replacement != None:
|
|
|
+ self[i] = replacement
|
|
|
+ else:
|
|
|
+ del self.nodes[i]
|
|
|
+
|
|
|
+ return
|
|
|
+
|
|
|
+ raise ValueError('Node "%s" is not in the scope of "%s".'
|
|
|
+ % (node, self.node))
|
|
|
+
|
|
|
+ def as_nary_node(self):
|
|
|
+ return nary_node(self.node.value, self.nodes)
|
|
|
+
|
|
|
+
|
|
|
+def nary_node(operator, scope):
|
|
|
+ """
|
|
|
+ Create a binary expression tree for an n-ary operator. Takes the operator
|
|
|
+ and a list of expression nodes as arguments.
|
|
|
+ """
|
|
|
+ return scope[0] if len(scope) == 1 \
|
|
|
+ else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
|
|
|
+
|
|
|
+
|
|
|
+def get_scope(node):
|
|
|
+ """
|
|
|
+ Find all n nodes within the n-ary scope of an operator node.
|
|
|
+ """
|
|
|
+ scope = []
|
|
|
+
|
|
|
+ for child in node:
|
|
|
+ if child.is_op(node.op):
|
|
|
+ scope += get_scope(child)
|
|
|
+ else:
|
|
|
+ scope.append(child)
|
|
|
+
|
|
|
+ return scope
|