Commit dad8585c authored by Taddeus Kroes's avatar Taddeus Kroes

Merged changes of Sander.

parent d6682c57
......@@ -49,7 +49,30 @@ OP_MAP = {
}
class ExpressionNode(Node):
class ExpressionBase(object):
def is_leaf(self):
return self.type != TYPE_OPERATOR
def is_power(self):
return not self.is_leaf() and self.op == OP_POW
def is_nary(self):
return not self.is_leaf() and self.op in [OP_ADD, OP_SUB, OP_MUL]
def is_identifier(self):
return self.is_leaf() and self.type & TYPE_IDENTIFIER
def is_int(self):
return self.is_leaf() and self.type & TYPE_INTEGER
def is_float(self):
return self.is_leaf() and self.type & TYPE_FLOAT
def is_numeric(self):
return self.is_leaf() and self.type & (TYPE_FLOAT | TYPE_INTEGER)
class ExpressionNode(Node, ExpressionBase):
def __init__(self, *args, **kwargs):
super(ExpressionNode, self).__init__(*args, **kwargs)
self.type = TYPE_OPERATOR
......@@ -70,23 +93,24 @@ class ExpressionNode(Node):
node.parent = self.parent
self.parent = None
def is_power(self):
return self.op == OP_POW
def is_nary(self):
return self.op in [OP_ADD, OP_SUB, OP_MUL]
def get_order(self):
if self.is_power() and self[0].is_identifier() \
and isinstance(self[1], Leaf):
if self.is_power() and self[0].is_identifier() and self[1].is_leaf():
# a ^ 3
return (self[0].value, self[1].value, 1)
if not self.op & OP_MUL:
return
for n0, n1 in [(0, 1), (1, 0)]:
if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \
and self[n1].is_power():
if self[n0].is_numeric():
if self[n1].is_identifier():
# 2 * a
return (self[n1].value, 1, self[n0].value)
elif self[n1].is_power():
# 2 * a ^ 3
coeff, power = self
if power[0].is_identifier() and isinstance(power[1], Leaf):
if power[0].is_identifier() and power[1].is_leaf():
return (power[0].value, power[1].value, coeff.value)
def get_scope(self):
......@@ -100,8 +124,11 @@ class ExpressionNode(Node):
return scope
def get_scope_except(self, *args):
return list(set(self.get_scope()) - set(args))
class ExpressionLeaf(Leaf):
class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs)
......@@ -120,18 +147,6 @@ class ExpressionLeaf(Leaf):
node.parent = self.parent
self.parent = None
def is_identifier(self):
return self.type & TYPE_IDENTIFIER
def is_int(self):
return self.type & TYPE_INTEGER
def is_float(self):
return self.type & TYPE_FLOAT
def is_numeric(self):
return self.type & (TYPE_FLOAT | TYPE_INTEGER)
if __name__ == '__main__': # pragma: nocover
l0 = ExpressionLeaf(3)
......
......@@ -19,8 +19,9 @@ sys.path.insert(1, EXTERNAL_MODS)
from pybison import BisonParser, BisonSyntaxError
from graph_drawing.graph import generate_graph
from node import TYPE_OPERATOR, OP_ADD, OP_MUL, OP_SUB
from node import TYPE_OPERATOR
from rules import RULES
from possibilities import filter_duplicates
## Check for n-ary operator in child nodes
......@@ -90,7 +91,10 @@ class Parser(BisonParser):
def hook_read_before(self):
if self.interactive and self.possibilities:
print 'possibilities:'
print self.possibilities
items = filter_duplicates(self.possibilities)
print ' ' + '\n '.join(map(str, items))
self.possibilities = []
def hook_read_after(self, data):
"""
......@@ -152,16 +156,17 @@ class Parser(BisonParser):
return data
def hook_handler(self, target, option, names, values, retval):
if not retval or retval.type not in RULES:
if target in ['exp', 'line', 'input'] or not retval \
or retval.type != TYPE_OPERATOR or retval.op not in RULES:
return retval
for handler in RULES[retval.type]:
for handler in RULES[retval.op]:
self.possibilities.extend(handler(retval))
return retval
def hook_run(self, filename, retval):
return retval
#def hook_run(self, filename, retval):
# return retval
# ---------------------------------------------------------------
# These methods are the python handlers for the bison targets.
......@@ -398,5 +403,4 @@ def main():
if interactive:
print
return node
......@@ -3,3 +3,10 @@ class Possibility(object):
self.root = root
self.handler = handler
self.args = args
def __str__(self):
return '<possibility root="%s" handler=%s args=%s>' \
% (self.root, self.handler, self.args)
def __repr__(self):
return str(self)
from itertools import combinations
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD, OP_MUL
from ..possibilities import Possibility as P
......@@ -8,8 +8,8 @@ def match_expand(node):
"""
a * (b + c) -> ab + ac
"""
if node.type != TYPE_OPERATOR or not node.op & OP_MUL:
return []
assert node.type == TYPE_OPERATOR
assert node.op & OP_MUL
p = []
......@@ -33,6 +33,7 @@ def match_expand(node):
return p
def expand_single(root, args):
"""
Combine a leaf (left) multiplied with an addition of two expressions
......@@ -41,15 +42,25 @@ def expand_single(root, args):
>>> a * (b + c) -> a * b + a * c
"""
left, right = args
others = list(set(root.get_scope()) - {left, right})
scope = root.get_scope_except(right)
replacement = Node('+', Node('*', left, right[0]), Node('*', left, right[1]))
for i, n in enumerate(scope):
if n == left:
scope[i] = replacement
break
return nary_node('*', scope)
def match_combine_factors(node):
"""
n + exp + m -> exp + (n + m)
k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
"""
if node.type != TYPE_OPERATOR or not node.op & OP_ADD:
return []
assert node.type == TYPE_OPERATOR
assert node.op & OP_ADD
p = []
......@@ -62,17 +73,18 @@ def match_combine_factors(node):
orders = []
for n in node.get_scope():
if node.type == TYPE_OPERATOR:
if n.type == TYPE_OPERATOR:
order = n.get_order()
if order:
orders += order
orders.append(order)
else:
if n.is_numeric():
numerics.append(n)
elif n.is_identifier():
orders.append((n.value, 1, 1))
print 'numerics:', numerics
if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1)))
......
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L
from src.rules.poly import match_combine_factors, combine_numerics, \
combine_orders
from src.rules.utils import nary_node
from src.possibilities import Possibility as P
class TestRules(unittest.TestCase):
......@@ -18,23 +15,3 @@ class TestRules(unittest.TestCase):
l0, l1, l2 = L(1), L(2), L(3)
plus = N('+', N('+', l0, l1), l2)
self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
def test_match_combine_factors_numeric_simple(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
def test_match_combine_factors_numeric_combinations(self):
l0, l1, l2 = L(1), L(2), L(2)
plus = N('+', N('+', l0, l1), l2)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
P(plus, combine_numerics, (l0, l2)),
P(plus, combine_numerics, (l1, l2))])
def assertEqualPos(self, possibilities, expected):
for p, e in zip(possibilities, expected):
self.assertEqual(p.root, e.root)
self.assertEqual(p.handler, e.handler)
self.assertEqual(p.args, e.args)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment