Skip to content
Snippets Groups Projects
Commit d1eb149c authored by Sander Mathijs van Veen's avatar Sander Mathijs van Veen
Browse files

Fixed merge conflict.

parents 9b18453b c6147911
No related branches found
No related tags found
No related merge requests found
......@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
return (self[0], self[1], ExpressionLeaf(1))
return (self[1], self[0], ExpressionLeaf(1))
def get_scope(self):
"""
Find all n nodes within the n-ary scope of this operator.
"""
scope = []
#op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
# TODO: what to do with OP_SUB and OP_ADD in get_scope?
for child in self:
if not child.is_leaf() and child.op == self.op:
scope += child.get_scope()
else:
scope.append(child)
return scope
def equals(self, other):
"""
Perform a non-strict equivalence check between two nodes:
......@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
return False
if self.op in (OP_ADD, OP_MUL):
s0 = self.get_scope()
s1 = set(other.get_scope())
s0 = Scope(self)
s1 = set(Scope(other))
# Scopes sould be of equal size
if len(s0) != len(s1):
......@@ -354,3 +337,74 @@ class ExpressionLeaf(Leaf, ExpressionBase):
"""
# rule: 1 * r ^ 1 -> (1, r, 1)
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = get_scope(node)
def __getitem__(self, key):
return self.nodes[key]
def __setitem__(self, key, value):
self.nodes[key] = value
def __len__(self):
return len(self.nodes)
def __iter__(self):
return iter(self.nodes)
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
if len(scope) == 1:
return scope[0]
return ExpressionNode(operator, nary_node(operator, scope[:-1]), scope[-1])
def get_scope(node):
"""
Find all n nodes within the n-ary scope of an operator node.
"""
scope = []
for child in node:
if child.is_op(node.op):
scope += get_scope(child)
else:
scope.append(child)
return scope
from itertools import product, combinations
from .utils import nary_node
from ..node import OP_ADD, OP_MUL, OP_NEG
from ..node import Scope, OP_ADD, OP_MUL, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -18,7 +17,7 @@ def match_expand(node):
leaves = []
additions = []
for n in node.get_scope():
for n in Scope(node):
if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
leaves.append(n)
elif n.op == OP_ADD:
......@@ -43,15 +42,15 @@ def expand_single(root, args):
"""
a, bc = args
b, c = bc
scope = root.get_scope()
scope = Scope(root)
# Replace 'a' with the new expression
scope[scope.index(a)] = a * b + a * c
scope.remove(a, a * b + a * c)
# Remove the addition
scope.remove(bc)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[expand_single] = _('Expand {1}({2}) to {1}({2[0]}) + {1}({2[1]}).')
......@@ -64,15 +63,15 @@ def expand_double(root, args):
(a + b) * (c + d) -> ac + ad + bc + bd
"""
(a, b), (c, d) = ab, cd = args
scope = root.get_scope()
scope = Scope(root)
# Replace 'b + c' with the new expression
scope[scope.index(ab)] = a * c + a * d + b * c + b * d
# Replace 'a + b' with the new expression
scope.remove(ab, a * c + a * d + b * c + b * d)
# Remove the right addition
scope.remove(cd)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[expand_double] = _('Expand ({1})({2}) to {1[0]}{2[0]} + {1[0]}{2[1]}'
......
from itertools import combinations
from .utils import nary_node, least_common_multiple
from ..node import ExpressionLeaf as L, OP_DIV, OP_ADD, OP_MUL, OP_NEG
from .utils import least_common_multiple
from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -84,7 +84,7 @@ def match_add_constant_fractions(node):
return node.is_op(OP_DIV) or \
(node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
fractions = filter(is_division, node.get_scope())
fractions = filter(is_division, Scope(node))
for a, b in combinations(fractions, 2):
if a.is_op(OP_NEG):
......@@ -117,7 +117,7 @@ def equalize_denominators(root, args):
"""
denom = args[2]
scope = root.get_scope()
scope = Scope(root)
for fraction in args[:2]:
n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction
......@@ -127,11 +127,11 @@ def equalize_denominators(root, args):
n = L(n.value * mult) if n.is_numeric() else L(mult) * n
if fraction.is_op(OP_NEG):
scope[scope.index(fraction)] = -(n / L(d.value * mult))
scope.remove(fraction, -(n / L(d.value * mult)))
else:
scope[scope.index(fraction)] = n / L(d.value * mult)
scope.remove(fraction, n / L(d.value * mult))
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[equalize_denominators] = _('Equalize the denominators of division'
......@@ -157,17 +157,15 @@ def add_nominators(root, args):
else:
c = cb[0]
substitution = (a + c) / b
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(ab)] = substitution
scope.remove(ab, (a + c) / b)
# Remove the right node
scope.remove(cb)
return nary_node('+', scope)
return scope.as_nary_node()
# TODO: convert this to a lambda. Example: 22 / 77 - 28 / 77. the "-" is above
......
from itertools import combinations
from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
ExpressionLeaf as Leaf
from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
from .utils import nary_node
def match_combine_groups(node):
......@@ -25,13 +24,13 @@ def match_combine_groups(node):
p = []
groups = []
for n in node.get_scope():
for n in Scope(node):
groups.append((1, n, n))
# Each number multiplication yields a group, multiple occurences of
# the same group can be replaced by a single one
if n.is_op(OP_MUL):
scope = n.get_scope()
scope = Scope(n)
l = len(scope)
for i, sub_node in enumerate(scope):
......@@ -55,18 +54,18 @@ def match_combine_groups(node):
def combine_groups(root, args):
c0, g0, n0, c1, g1, n1 = args
scope = root.get_scope()
scope = Scope(root)
if not isinstance(c0, Leaf):
c0 = Leaf(c0)
# Replace the left node with the new expression
scope[scope.index(n0)] = (c0 + c1) * g0
scope.remove(n0, (c0 + c1) * g0)
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[combine_groups] = \
......
from itertools import combinations
from .utils import nary_node
from ..node import ExpressionLeaf as Leaf, OP_DIV, OP_MUL, OP_NEG
from ..node import ExpressionLeaf as Leaf, Scope, nary_node, OP_DIV, OP_MUL, \
OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -28,15 +28,15 @@ def add_numerics(root, args):
else:
c1 = c1.value
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = Leaf(c0 + c1)
scope.remove(n0, Leaf(c0 + c1))
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[add_numerics] = _('Combine the constants {1} and {2}, which'
......@@ -119,7 +119,7 @@ def match_multiply_numerics(node):
p = []
numerics = []
for n in node.get_scope():
for n in Scope(node):
if n.is_numeric():
numerics.append((n, n.value))
elif n.is_op(OP_NEG) and n[0].is_numeric():
......@@ -147,7 +147,7 @@ def multiply_numerics(root, args):
else:
substitution = -Leaf(-value)
for n in root.get_scope():
for n in Scope(root):
if hash(n) == hash(n0):
# Replace the left node with the new expression
scope.append(substitution)
......
from itertools import combinations
from ..node import OP_ADD, OP_NEG
from ..node import Scope, OP_ADD, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
from .numerics import add_numerics
......@@ -32,7 +31,7 @@ def match_combine_polynomes(node, verbose=False):
if verbose: # pragma: nocover
print 'match combine factors:', node
for n in node.get_scope():
for n in Scope(node):
polynome = n.extract_polynome_properties()
if verbose: # pragma: nocover
......@@ -84,16 +83,14 @@ def combine_polynomes(root, args):
else:
power = r ** e
# replacement: (c0 + c1) * a ^ b
# a, b and c are from 'left', d is from 'right'.
replacement = (c0 + c1) * power
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = replacement
# Replace the left node with the new expression:
# (c0 + c1) * a ^ b
# a, b and c are from 'left', d is from 'right'.
scope.remove(n0, (c0 + c1) * power)
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
from itertools import combinations
from ..node import ExpressionNode as N, ExpressionLeaf as L, \
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
from ..translate import _
......@@ -19,7 +18,7 @@ def match_add_exponents(node):
p = []
powers = {}
for n in node.get_scope():
for n in Scope(node):
if n.is_identifier():
s = n
exponent = L(1)
......@@ -52,15 +51,15 @@ def add_exponents(root, args):
a^p * a^q -> a^(p + q)
"""
n0, n1, a, p, q = args
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = a ** (p + q)
scope.remove(n0, a ** (p + q))
# Remove the right node
scope.remove(n1)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
......@@ -116,7 +115,7 @@ def match_duplicate_exponent(node):
left, right = node
if left.is_op(OP_MUL):
return [P(node, duplicate_exponent, (left.get_scope(), right))]
return [P(node, duplicate_exponent, (list(Scope(left)), right))]
return []
......@@ -159,7 +158,7 @@ def match_extend_exponent(node):
left, right = node
if right.is_numeric():
for n in node.get_scope():
for n in Scope(node):
if n.is_op(OP_ADD):
return [P(node, extend_exponent, (left, right))]
......
from ..node import ExpressionNode as Node
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
return scope[0] if len(scope) == 1 \
else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
def gcd(a, b):
"""
Return greatest common divisor using Euclid's Algorithm.
......
from rules.utils import nary_node
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = node.get_scope()
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self.nodes[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
nary_node, get_scope, OP_ADD
from tests.rulestestcase import RulesTestCase, tree
from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
from tests.rulestestcase import tree
class TestNode(unittest.TestCase):
class TestNode(RulesTestCase):
def setUp(self):
self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___lt__(self):
self.assertTrue(L(1) < L(2))
......@@ -95,19 +98,19 @@ class TestNode(unittest.TestCase):
def test_get_scope_binary(self):
plus = N('+', *self.l[:2])
self.assertEqual(plus.get_scope(), self.l[:2])
self.assertEqual(get_scope(plus), self.l[:2])
def test_get_scope_nested_left(self):
plus = N('+', N('+', *self.l[:2]), self.l[2])
self.assertEqual(plus.get_scope(), self.l[:3])
self.assertEqual(get_scope(plus), self.l[:3])
def test_get_scope_nested_right(self):
plus = N('+', self.l[0], N('+', *self.l[1:3]))
self.assertEqual(plus.get_scope(), self.l[:3])
self.assertEqual(get_scope(plus), self.l[:3])
def test_get_scope_nested_deep(self):
plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
self.assertEqual(plus.get_scope(), self.l)
self.assertEqual(get_scope(plus), self.l)
def test_equals_node_leaf(self):
a, b = plus = tree('a + b')
......@@ -168,3 +171,36 @@ class TestNode(unittest.TestCase):
m0, m1 = tree('-5 * -3,-5 * 6')
self.assertFalse(m0.equals(m1))
def test_scope___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_scope_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_scope_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_scope_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_scope_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_nary_node(self):
a, b, c, d = tree('a,b,c,d')
self.assertEqualNodes(nary_node('+', [a]), a)
self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
self.assertEqualNodes(nary_node('+', [a, b, c]),
N('+', N('+', a, b), c))
self.assertEqualNodes(nary_node('+', [a, b, c, d]),
N('+', N('+', N('+', a, b), c), d))
def test_scope_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L
from src.rules.utils import nary_node
class TestRules(unittest.TestCase):
def test_nary_node_binary(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
self.assertEqual(nary_node('+', [l0, l1]), plus)
def test_nary_node_ternary(self):
l0, l1, l2 = L(1), L(2), L(3)
plus = N('+', N('+', l0, l1), l2)
self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
from src.node import ExpressionNode as N
from src.rules.utils import nary_node, least_common_multiple
from tests.rulestestcase import RulesTestCase, tree
import unittest
from src.rules.utils import least_common_multiple
class TestRulesUtils(RulesTestCase):
def test_nary_node(self):
a, b, c, d = tree('a,b,c,d')
self.assertEqualNodes(nary_node('+', [a]), a)
self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
self.assertEqualNodes(nary_node('+', [a, b, c]),
N('+', N('+', a, b), c))
self.assertEqualNodes(nary_node('+', [a, b, c, d]),
N('+', N('+', N('+', a, b), c), d))
class TestRulesUtils(unittest.TestCase):
def test_least_common_multiple(self):
self.assertEqual(least_common_multiple(5, 6), 30)
......
import unittest
from src.scope import Scope
from tests.rulestestcase import RulesTestCase, tree
class TestScope(RulesTestCase):
def setUp(self):
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment