Commit dad8585c authored by Taddeus Kroes's avatar Taddeus Kroes

Merged changes of Sander.

parent d6682c57
...@@ -49,7 +49,30 @@ OP_MAP = { ...@@ -49,7 +49,30 @@ OP_MAP = {
} }
class ExpressionNode(Node): class ExpressionBase(object):
def is_leaf(self):
return self.type != TYPE_OPERATOR
def is_power(self):
return not self.is_leaf() and self.op == OP_POW
def is_nary(self):
return not self.is_leaf() and self.op in [OP_ADD, OP_SUB, OP_MUL]
def is_identifier(self):
return self.is_leaf() and self.type & TYPE_IDENTIFIER
def is_int(self):
return self.is_leaf() and self.type & TYPE_INTEGER
def is_float(self):
return self.is_leaf() and self.type & TYPE_FLOAT
def is_numeric(self):
return self.is_leaf() and self.type & (TYPE_FLOAT | TYPE_INTEGER)
class ExpressionNode(Node, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionNode, self).__init__(*args, **kwargs) super(ExpressionNode, self).__init__(*args, **kwargs)
self.type = TYPE_OPERATOR self.type = TYPE_OPERATOR
...@@ -70,23 +93,24 @@ class ExpressionNode(Node): ...@@ -70,23 +93,24 @@ class ExpressionNode(Node):
node.parent = self.parent node.parent = self.parent
self.parent = None self.parent = None
def is_power(self):
return self.op == OP_POW
def is_nary(self):
return self.op in [OP_ADD, OP_SUB, OP_MUL]
def get_order(self): def get_order(self):
if self.is_power() and self[0].is_identifier() \ if self.is_power() and self[0].is_identifier() and self[1].is_leaf():
and isinstance(self[1], Leaf): # a ^ 3
return (self[0].value, self[1].value, 1) return (self[0].value, self[1].value, 1)
if not self.op & OP_MUL:
return
for n0, n1 in [(0, 1), (1, 0)]: for n0, n1 in [(0, 1), (1, 0)]:
if self[n0].is_numeric() and not isinstance(self[n1], Leaf) \ if self[n0].is_numeric():
and self[n1].is_power(): if self[n1].is_identifier():
# 2 * a
return (self[n1].value, 1, self[n0].value)
elif self[n1].is_power():
# 2 * a ^ 3
coeff, power = self coeff, power = self
if power[0].is_identifier() and isinstance(power[1], Leaf): if power[0].is_identifier() and power[1].is_leaf():
return (power[0].value, power[1].value, coeff.value) return (power[0].value, power[1].value, coeff.value)
def get_scope(self): def get_scope(self):
...@@ -100,8 +124,11 @@ class ExpressionNode(Node): ...@@ -100,8 +124,11 @@ class ExpressionNode(Node):
return scope return scope
def get_scope_except(self, *args):
return list(set(self.get_scope()) - set(args))
class ExpressionLeaf(Leaf): class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs) super(ExpressionLeaf, self).__init__(*args, **kwargs)
...@@ -120,18 +147,6 @@ class ExpressionLeaf(Leaf): ...@@ -120,18 +147,6 @@ class ExpressionLeaf(Leaf):
node.parent = self.parent node.parent = self.parent
self.parent = None self.parent = None
def is_identifier(self):
return self.type & TYPE_IDENTIFIER
def is_int(self):
return self.type & TYPE_INTEGER
def is_float(self):
return self.type & TYPE_FLOAT
def is_numeric(self):
return self.type & (TYPE_FLOAT | TYPE_INTEGER)
if __name__ == '__main__': # pragma: nocover if __name__ == '__main__': # pragma: nocover
l0 = ExpressionLeaf(3) l0 = ExpressionLeaf(3)
......
...@@ -19,8 +19,9 @@ sys.path.insert(1, EXTERNAL_MODS) ...@@ -19,8 +19,9 @@ sys.path.insert(1, EXTERNAL_MODS)
from pybison import BisonParser, BisonSyntaxError from pybison import BisonParser, BisonSyntaxError
from graph_drawing.graph import generate_graph from graph_drawing.graph import generate_graph
from node import TYPE_OPERATOR, OP_ADD, OP_MUL, OP_SUB from node import TYPE_OPERATOR
from rules import RULES from rules import RULES
from possibilities import filter_duplicates
## Check for n-ary operator in child nodes ## Check for n-ary operator in child nodes
...@@ -90,7 +91,10 @@ class Parser(BisonParser): ...@@ -90,7 +91,10 @@ class Parser(BisonParser):
def hook_read_before(self): def hook_read_before(self):
if self.interactive and self.possibilities: if self.interactive and self.possibilities:
print 'possibilities:' print 'possibilities:'
print self.possibilities items = filter_duplicates(self.possibilities)
print ' ' + '\n '.join(map(str, items))
self.possibilities = []
def hook_read_after(self, data): def hook_read_after(self, data):
""" """
...@@ -152,16 +156,17 @@ class Parser(BisonParser): ...@@ -152,16 +156,17 @@ class Parser(BisonParser):
return data return data
def hook_handler(self, target, option, names, values, retval): def hook_handler(self, target, option, names, values, retval):
if not retval or retval.type not in RULES: if target in ['exp', 'line', 'input'] or not retval \
or retval.type != TYPE_OPERATOR or retval.op not in RULES:
return retval return retval
for handler in RULES[retval.type]: for handler in RULES[retval.op]:
self.possibilities.extend(handler(retval)) self.possibilities.extend(handler(retval))
return retval return retval
def hook_run(self, filename, retval): #def hook_run(self, filename, retval):
return retval # return retval
# --------------------------------------------------------------- # ---------------------------------------------------------------
# These methods are the python handlers for the bison targets. # These methods are the python handlers for the bison targets.
...@@ -398,5 +403,4 @@ def main(): ...@@ -398,5 +403,4 @@ def main():
if interactive: if interactive:
print print
return node return node
...@@ -3,3 +3,10 @@ class Possibility(object): ...@@ -3,3 +3,10 @@ class Possibility(object):
self.root = root self.root = root
self.handler = handler self.handler = handler
self.args = args self.args = args
def __str__(self):
return '<possibility root="%s" handler=%s args=%s>' \
% (self.root, self.handler, self.args)
def __repr__(self):
return str(self)
from itertools import combinations from itertools import combinations
from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD from ..node import ExpressionLeaf as Leaf, TYPE_OPERATOR, OP_ADD, OP_MUL
from ..possibilities import Possibility as P from ..possibilities import Possibility as P
...@@ -8,8 +8,8 @@ def match_expand(node): ...@@ -8,8 +8,8 @@ def match_expand(node):
""" """
a * (b + c) -> ab + ac a * (b + c) -> ab + ac
""" """
if node.type != TYPE_OPERATOR or not node.op & OP_MUL: assert node.type == TYPE_OPERATOR
return [] assert node.op & OP_MUL
p = [] p = []
...@@ -33,6 +33,7 @@ def match_expand(node): ...@@ -33,6 +33,7 @@ def match_expand(node):
return p return p
def expand_single(root, args): def expand_single(root, args):
""" """
Combine a leaf (left) multiplied with an addition of two expressions Combine a leaf (left) multiplied with an addition of two expressions
...@@ -41,15 +42,25 @@ def expand_single(root, args): ...@@ -41,15 +42,25 @@ def expand_single(root, args):
>>> a * (b + c) -> a * b + a * c >>> a * (b + c) -> a * b + a * c
""" """
left, right = args left, right = args
others = list(set(root.get_scope()) - {left, right}) scope = root.get_scope_except(right)
replacement = Node('+', Node('*', left, right[0]), Node('*', left, right[1]))
for i, n in enumerate(scope):
if n == left:
scope[i] = replacement
break
return nary_node('*', scope)
def match_combine_factors(node): def match_combine_factors(node):
""" """
n + exp + m -> exp + (n + m) n + exp + m -> exp + (n + m)
k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n k0 * v ^ n + exp + k1 * v ^ n -> exp + (k0 + k1) * v ^ n
""" """
if node.type != TYPE_OPERATOR or not node.op & OP_ADD: assert node.type == TYPE_OPERATOR
return [] assert node.op & OP_ADD
p = [] p = []
...@@ -62,17 +73,18 @@ def match_combine_factors(node): ...@@ -62,17 +73,18 @@ def match_combine_factors(node):
orders = [] orders = []
for n in node.get_scope(): for n in node.get_scope():
if node.type == TYPE_OPERATOR: if n.type == TYPE_OPERATOR:
order = n.get_order() order = n.get_order()
if order: if order:
orders += order orders.append(order)
else: else:
if n.is_numeric(): if n.is_numeric():
numerics.append(n) numerics.append(n)
elif n.is_identifier(): elif n.is_identifier():
orders.append((n.value, 1, 1)) orders.append((n.value, 1, 1))
print 'numerics:', numerics
if len(numerics) > 1: if len(numerics) > 1:
for num0, num1 in combinations(numerics, 2): for num0, num1 in combinations(numerics, 2):
p.append(P(node, combine_numerics, (num0, num1))) p.append(P(node, combine_numerics, (num0, num1)))
......
import unittest import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L from src.node import ExpressionNode as N, ExpressionLeaf as L
from src.rules.poly import match_combine_factors, combine_numerics, \
combine_orders
from src.rules.utils import nary_node from src.rules.utils import nary_node
from src.possibilities import Possibility as P
class TestRules(unittest.TestCase): class TestRules(unittest.TestCase):
...@@ -18,23 +15,3 @@ class TestRules(unittest.TestCase): ...@@ -18,23 +15,3 @@ class TestRules(unittest.TestCase):
l0, l1, l2 = L(1), L(2), L(3) l0, l1, l2 = L(1), L(2), L(3)
plus = N('+', N('+', l0, l1), l2) plus = N('+', N('+', l0, l1), l2)
self.assertEqual(nary_node('+', [l0, l1, l2]), plus) self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
def test_match_combine_factors_numeric_simple(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1))])
def test_match_combine_factors_numeric_combinations(self):
l0, l1, l2 = L(1), L(2), L(2)
plus = N('+', N('+', l0, l1), l2)
p = match_combine_factors(plus)
self.assertEqualPos(p, [P(plus, combine_numerics, (l0, l1)),
P(plus, combine_numerics, (l0, l2)),
P(plus, combine_numerics, (l1, l2))])
def assertEqualPos(self, possibilities, expected):
for p, e in zip(possibilities, expected):
self.assertEqual(p.root, e.root)
self.assertEqual(p.handler, e.handler)
self.assertEqual(p.args, e.args)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment