Commit 513c04f3 authored by Taddeus Kroes's avatar Taddeus Kroes

Implemented some basis rewrite rules, along with unit tests for node.py.

parent 054c9265
...@@ -8,27 +8,110 @@ from graph_drawing.line import generate_line ...@@ -8,27 +8,110 @@ from graph_drawing.line import generate_line
from graph_drawing.node import Node, Leaf from graph_drawing.node import Node, Leaf
#NODE_TYPE = 0 TYPE_OPERATOR = 1
#NODE_ TYPE_IDENTIFIER = 2
TYPE_INTEGER = 4
TYPE_FLOAT = 8
TYPE_NUMERIC = TYPE_INTEGER | TYPE_FLOAT
# Unary
OP_NEG = 1
# Binary
OP_ADD = 2
OP_SUB = 3
OP_MUL = 4
OP_DIV = 5
OP_POW = 6
OP_MOD = 7
# N-ary (functions)
OP_INT = 8
OP_EXPAND = 9
TYPE_MAP = {
int: TYPE_INTEGER,
float: TYPE_FLOAT,
str: TYPE_IDENTIFIER,
}
OPT_MAP = {
'+': OP_ADD,
'-': OP_SUB,
'*': OP_MUL,
'/': OP_DIV,
'^': OP_POW,
'mod': OP_MOD,
'int': OP_INT,
'expand': OP_EXPAND,
}
class ExpressionNode(Node): class ExpressionNode(Node):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionNode, self).__init__(*args, **kwargs) super(ExpressionNode, self).__init__(*args, **kwargs)
#self.type = NODE_TYPE self.type = TYPE_OPERATOR
self.opt = OPT_MAP[args[0]]
def __str__(self): def __str__(self): # pragma: nocover
return generate_line(self) return generate_line(self)
def graph(self): # pragma: nocover
return generate_graph(self)
def replace(self, node): def replace(self, node):
pos = self.parent.nodes.index(self) pos = self.parent.nodes.index(self)
self.parent.nodes[pos] = node self.parent.nodes[pos] = node
node.parent = self.parent node.parent = self.parent
self.parent = None self.parent = None
def graph(self): def is_power(self):
return generate_graph(self) return self.opt == OP_POW
def is_nary(self):
return self.opt in [OP_ADD, OP_SUB, OP_MUL]
def get_order(self):
if self.is_power() and self[0].is_identifier() \
and isinstance(self[1], Leaf):
return (self[0].value, self[1].value, 1)
for n0, n1 in [(0, 1), (1, 0)]:
if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \
and self[n1].is_power():
coeff, power = self
if power[0].is_identifier() and isinstance(power[1], Leaf):
return (power[0].value, power[1].value, coeff.value)
def get_scope(self):
scope = []
for child in self:
if not isinstance(child, Leaf) and child.opt == self.opt:
scope += child.get_scope()
else:
scope.append(child)
return scope
class ExpressionLeaf(Leaf): class ExpressionLeaf(Leaf):
def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs)
for data_type, type_repr in TYPE_MAP.iteritems():
if isinstance(args[0], data_type):
self.type = type_repr
break
def get_order(self):
if self.is_identifier():
return (self.value, 1, 1)
def replace(self, node): def replace(self, node):
if not hasattr(self, 'parent'): if not hasattr(self, 'parent'):
return return
...@@ -38,8 +121,20 @@ class ExpressionLeaf(Leaf): ...@@ -38,8 +121,20 @@ class ExpressionLeaf(Leaf):
node.parent = self.parent node.parent = self.parent
self.parent = None self.parent = None
def is_identifier(self):
return self.type & TYPE_IDENTIFIER
def is_int(self):
return self.type & TYPE_INTEGER
def is_float(self):
return self.type & TYPE_FLOAT
def is_numeric(self):
return self.type & TYPE_NUMERIC
if __name__ == '__main__': if __name__ == '__main__': # pragma: nocover
l0 = ExpressionLeaf(3) l0 = ExpressionLeaf(3)
l1 = ExpressionLeaf(4) l1 = ExpressionLeaf(4)
l2 = ExpressionLeaf(5) l2 = ExpressionLeaf(5)
...@@ -67,8 +162,8 @@ if __name__ == '__main__': ...@@ -67,8 +162,8 @@ if __name__ == '__main__':
return res return res
possibilities = [ possibilities = [
(n0, lambda (x,y): ExpressionLeaf(x.value + y.value)), (n0, lambda (x, y): ExpressionLeaf(x.value + y.value)),
(n1, lambda (x,y): ExpressionLeaf(x.value + y.value)), (n1, lambda (x, y): ExpressionLeaf(x.value + y.value)),
(n2, rewrite_multiply), (n2, rewrite_multiply),
] ]
......
class Possibility(object):
def __init__(self, root, handler, args):
self.root = root
self.handler = handler
self.args = args
from node import ExpressionLeaf as Leaf from itertools import combinations
def get_factor_constants(operand): from node import ExpressionNode as Node, ExpressionLeaf as Leaf
op = operand.title() from possibilities import Possibility as P
res = []
if operand.type == OP_MUL:
if operand[0].type == LEAF_NUM:
fn()
if operand[1].type == LEAF_NUM: def match_combine_factors(node):
res += operand[1] """
n + exp + m -> exp + (n + m)
k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
"""
p = []
return res if node.is_nary():
# Collect all nodes that can be combined
# Numeric leaves
numerics = []
# Identifier leaves of all orders, tuple format is;
# (identifier, exponent, coefficient)
orders = []
# Nodes that cannot be combined
others = []
for n in node.get_scope():
if isinstance(n, Leaf):
if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
else:
order = n.get_order()
if order:
orders += order
else:
others.append(n)
if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1, others)))
if len(orders) > 1:
for order0, order1 in combinations(orders, 2):
id0, exponent0, coeff0 = order0
id1, exponent1, coeff1 = order1
if id0 == id1 and exponent0 == exponent1:
# Same identifier and exponent -> combine coefficients
args = order0 + (coeff1,) + (others,)
p.append(P(node, combine_orders, args))
def combine_plus_factors(node): return p
p = []
# Check if any numeric factors can be combined
def apply_numeric_factors(node, leaves):
return Leaf(reduce(lambda a, b: a.value + b.value, leaves))
num_nodes = [] def combine_numerics(root, args):
"""
Combine two numeric leaves in an n-ary plus.
for n in node: Example:
# NUM + NUM -> NUM 3 + 4 -> 7
if n.type == VAL_NUM: """
num_nodes.append(n) numerics, others = args
value = sum([n.value for n in numerics])
if len(num_nodes) > 1: return nary_node('+', others + [Leaf(value)])
p.append((node, apply_plus_factors, num_nodes))
# Check if any variable multiplcations/divisions can be combined
def apply_identifiers(node, operands):
apply_constant(lambda x: )
return Leaf(leaves[0].value + leaves[1].value)
id_nodes = [] def combine_orders(root, args):
"""
Combine two identifier multiplications of any order in an n-ary plus.
for n in node: Example:
# NUM * + NUM -> NUM 3x + 4x -> 7x
if n.type == OP_MUL: """
consts = get_factor_constants(n) identifier, exponent, coeff0, coeff1, others = args
if len(consts) > 1: coeff = coeff0 + coeff1
id_nodes +=
if len(num_nodes) > 1: if not exponent:
p.append((node, apply_plus_factors, num_nodes)) # a ^ 0 -> 1
ident = Leaf(1)
elif exponent == 1:
# a ^ 1 -> a
ident = Leaf(identifier)
else:
# a ^ n -> a ^ n
ident = Node('^', Leaf(identifier), Leaf(exponent))
if coeff == 1:
combined = ident
else:
combined = Node('*', Leaf(coeff), ident)
return nary_node('+', others + [combined])
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
return scope[0] if len(scope) == 1 \
else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
return p
rules = { rules = {
'+': [combine_plus_factors], '+': [match_combine_factors],
} }
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L
class TestNode(unittest.TestCase):
def setUp(self):
self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
def test_replace_node(self):
inner = N('+', L(1), L(2))
node = N('+', inner, L(3))
replacement = N('-', L(4), L(5))
inner.replace(replacement)
self.assertEqual(str(node), '4 - 5 + 3')
def test_replace_leaf(self):
inner = N('+', L(1), L(2))
node = N('+', inner, L(3))
replacement = L(4)
inner.replace(replacement)
self.assertEqual(str(node), '4 + 3')
def test_is_power_true(self):
self.assertTrue(N('^', *self.l[:2]).is_power())
self.assertFalse(N('+', *self.l[:2]).is_power())
def test_is_nary(self):
self.assertTrue(N('+', *self.l[:2]).is_nary())
self.assertTrue(N('-', *self.l[:2]).is_nary())
self.assertTrue(N('*', *self.l[:2]).is_nary())
self.assertFalse(N('^', *self.l[:2]).is_nary())
def test_is_identifier(self):
self.assertTrue(L('a').is_identifier())
self.assertFalse(L(1).is_identifier())
def test_is_int(self):
self.assertTrue(L(1).is_int())
self.assertFalse(L(1.5).is_int())
self.assertFalse(L('a').is_int())
def test_is_float(self):
self.assertTrue(L(1.5).is_float())
self.assertFalse(L(1).is_float())
self.assertFalse(L('a').is_float())
def test_is_numeric(self):
self.assertTrue(L(1).is_numeric())
self.assertTrue(L(1.5).is_numeric())
self.assertFalse(L('a').is_numeric())
def test_get_order_identifier(self):
self.assertEqual(L('a').get_order(), ('a', 1, 1))
def test_get_order_None(self):
self.assertIsNone(L(1).get_order())
def test_get_order_power(self):
power = N('^', L('a'), L(2))
self.assertEqual(power.get_order(), ('a', 2, 1))
def test_get_order_coefficient_exponent_int(self):
times = N('*', L(3), N('^', L('a'), L(2)))
self.assertEqual(times.get_order(), ('a', 2, 3))
def test_get_order_coefficient_exponent_id(self):
times = N('*', L(3), N('^', L('a'), L('b')))
self.assertEqual(times.get_order(), ('a', 'b', 3))
def test_get_scope_binary(self):
plus = N('+', *self.l[:2])
self.assertEqual(plus.get_scope(), self.l[:2])
def test_get_scope_nested_left(self):
plus = N('+', N('+', *self.l[:2]), self.l[2])
self.assertEqual(plus.get_scope(), self.l[:3])
def test_get_scope_nested_right(self):
plus = N('+', self.l[0], N('+', *self.l[1:3]))
self.assertEqual(plus.get_scope(), self.l[:3])
def test_get_scope_nested_deep(self):
plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
self.assertEqual(plus.get_scope(), self.l)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L
from src.rules import match_combine_factors, combine_numerics, \
combine_orders, nary_node
from src.possibilities import Possibility as P
class TestRules(unittest.TestCase):
def test_nary_node_binary(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
self.assertEqual(nary_node('+', [l0, l1]), plus)
def test_nary_node_ternary(self):
l0, l1, l2 = L(1), L(2), L(3)
plus = N('+', N('+', l0, l1), l2)
self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
def test_match_combine_factors_numeric_simple(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, []))])
def test_match_combine_factors_numeric_combinations(self):
l0, l1, l2 = L(1), L(2), L(2)
plus = N('+', N('+', l0, l1), l2)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1, [])),
P(plus, combine_numerics, (l0, l2, [])),
P(plus, combine_numerics, (l1, l2, []))])
def assertEqualPos(self, possibilities, expected):
for p, e in zip(possibilities, expected):
self.assertEqual(p.root, e.root)
self.assertEqual(p.handler, e.handler)
self.assertEqual(p.args, e.args)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment