Commit d1eb149c authored by Sander Mathijs van Veen's avatar Sander Mathijs van Veen

Fixed merge conflict.

parents 9b18453b c6147911
......@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
return (self[0], self[1], ExpressionLeaf(1))
return (self[1], self[0], ExpressionLeaf(1))
def get_scope(self):
"""
Find all n nodes within the n-ary scope of this operator.
"""
scope = []
#op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
# TODO: what to do with OP_SUB and OP_ADD in get_scope?
for child in self:
if not child.is_leaf() and child.op == self.op:
scope += child.get_scope()
else:
scope.append(child)
return scope
def equals(self, other):
"""
Perform a non-strict equivalence check between two nodes:
......@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
return False
if self.op in (OP_ADD, OP_MUL):
s0 = self.get_scope()
s1 = set(other.get_scope())
s0 = Scope(self)
s1 = set(Scope(other))
# Scopes sould be of equal size
if len(s0) != len(s1):
......@@ -354,3 +337,74 @@ class ExpressionLeaf(Leaf, ExpressionBase):
"""
# rule: 1 * r ^ 1 -> (1, r, 1)
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = get_scope(node)
def __getitem__(self, key):
return self.nodes[key]
def __setitem__(self, key, value):
self.nodes[key] = value
def __len__(self):
return len(self.nodes)
def __iter__(self):
return iter(self.nodes)
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
if len(scope) == 1:
return scope[0]
return ExpressionNode(operator, nary_node(operator, scope[:-1]), scope[-1])
def get_scope(node):
"""
Find all n nodes within the n-ary scope of an operator node.
"""
scope = []
for child in node:
if child.is_op(node.op):
scope += get_scope(child)
else:
scope.append(child)
return scope
from itertools import product, combinations
from .utils import nary_node
from ..node import OP_ADD, OP_MUL, OP_NEG
from ..node import Scope, OP_ADD, OP_MUL, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -18,7 +17,7 @@ def match_expand(node):
leaves = []
additions = []
for n in node.get_scope():
for n in Scope(node):
if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
leaves.append(n)
elif n.op == OP_ADD:
......@@ -43,15 +42,15 @@ def expand_single(root, args):
"""
a, bc = args
b, c = bc
scope = root.get_scope()
scope = Scope(root)
# Replace 'a' with the new expression
scope[scope.index(a)] = a * b + a * c
scope.remove(a, a * b + a * c)
# Remove the addition
scope.remove(bc)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[expand_single] = _('Expand {1}({2}) to {1}({2[0]}) + {1}({2[1]}).')
......@@ -64,15 +63,15 @@ def expand_double(root, args):
(a + b) * (c + d) -> ac + ad + bc + bd
"""
(a, b), (c, d) = ab, cd = args
scope = root.get_scope()
scope = Scope(root)
# Replace 'b + c' with the new expression
scope[scope.index(ab)] = a * c + a * d + b * c + b * d
# Replace 'a + b' with the new expression
scope.remove(ab, a * c + a * d + b * c + b * d)
# Remove the right addition
scope.remove(cd)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[expand_double] = _('Expand ({1})({2}) to {1[0]}{2[0]} + {1[0]}{2[1]}'
......
from itertools import combinations
from .utils import nary_node, least_common_multiple
from ..node import ExpressionLeaf as L, OP_DIV, OP_ADD, OP_MUL, OP_NEG
from .utils import least_common_multiple
from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -84,7 +84,7 @@ def match_add_constant_fractions(node):
return node.is_op(OP_DIV) or \
(node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
fractions = filter(is_division, node.get_scope())
fractions = filter(is_division, Scope(node))
for a, b in combinations(fractions, 2):
if a.is_op(OP_NEG):
......@@ -117,7 +117,7 @@ def equalize_denominators(root, args):
"""
denom = args[2]
scope = root.get_scope()
scope = Scope(root)
for fraction in args[:2]:
n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction
......@@ -127,11 +127,11 @@ def equalize_denominators(root, args):
n = L(n.value * mult) if n.is_numeric() else L(mult) * n
if fraction.is_op(OP_NEG):
scope[scope.index(fraction)] = -(n / L(d.value * mult))
scope.remove(fraction, -(n / L(d.value * mult)))
else:
scope[scope.index(fraction)] = n / L(d.value * mult)
scope.remove(fraction, n / L(d.value * mult))
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[equalize_denominators] = _('Equalize the denominators of division'
......@@ -157,17 +157,15 @@ def add_nominators(root, args):
else:
c = cb[0]
substitution = (a + c) / b
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(ab)] = substitution
scope.remove(ab, (a + c) / b)
# Remove the right node
scope.remove(cb)
return nary_node('+', scope)
return scope.as_nary_node()
# TODO: convert this to a lambda. Example: 22 / 77 - 28 / 77. the "-" is above
......
from itertools import combinations
from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
ExpressionLeaf as Leaf
from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
from .utils import nary_node
def match_combine_groups(node):
......@@ -25,13 +24,13 @@ def match_combine_groups(node):
p = []
groups = []
for n in node.get_scope():
for n in Scope(node):
groups.append((1, n, n))
# Each number multiplication yields a group, multiple occurences of
# the same group can be replaced by a single one
if n.is_op(OP_MUL):
scope = n.get_scope()
scope = Scope(n)
l = len(scope)
for i, sub_node in enumerate(scope):
......@@ -55,18 +54,18 @@ def match_combine_groups(node):
def combine_groups(root, args):
c0, g0, n0, c1, g1, n1 = args
scope = root.get_scope()
scope = Scope(root)
if not isinstance(c0, Leaf):
c0 = Leaf(c0)
# Replace the left node with the new expression
scope[scope.index(n0)] = (c0 + c1) * g0
scope.remove(n0, (c0 + c1) * g0)
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[combine_groups] = \
......
from itertools import combinations
from .utils import nary_node
from ..node import ExpressionLeaf as Leaf, OP_DIV, OP_MUL, OP_NEG
from ..node import ExpressionLeaf as Leaf, Scope, nary_node, OP_DIV, OP_MUL, \
OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -28,15 +28,15 @@ def add_numerics(root, args):
else:
c1 = c1.value
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = Leaf(c0 + c1)
scope.remove(n0, Leaf(c0 + c1))
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
MESSAGES[add_numerics] = _('Combine the constants {1} and {2}, which'
......@@ -119,7 +119,7 @@ def match_multiply_numerics(node):
p = []
numerics = []
for n in node.get_scope():
for n in Scope(node):
if n.is_numeric():
numerics.append((n, n.value))
elif n.is_op(OP_NEG) and n[0].is_numeric():
......@@ -147,7 +147,7 @@ def multiply_numerics(root, args):
else:
substitution = -Leaf(-value)
for n in root.get_scope():
for n in Scope(root):
if hash(n) == hash(n0):
# Replace the left node with the new expression
scope.append(substitution)
......
from itertools import combinations
from ..node import OP_ADD, OP_NEG
from ..node import Scope, OP_ADD, OP_NEG
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
from .numerics import add_numerics
......@@ -32,7 +31,7 @@ def match_combine_polynomes(node, verbose=False):
if verbose: # pragma: nocover
print 'match combine factors:', node
for n in node.get_scope():
for n in Scope(node):
polynome = n.extract_polynome_properties()
if verbose: # pragma: nocover
......@@ -84,16 +83,14 @@ def combine_polynomes(root, args):
else:
power = r ** e
# replacement: (c0 + c1) * a ^ b
# a, b and c are from 'left', d is from 'right'.
replacement = (c0 + c1) * power
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = replacement
# Replace the left node with the new expression:
# (c0 + c1) * a ^ b
# a, b and c are from 'left', d is from 'right'.
scope.remove(n0, (c0 + c1) * power)
# Remove the right node
scope.remove(n1)
return nary_node('+', scope)
return scope.as_nary_node()
from itertools import combinations
from ..node import ExpressionNode as N, ExpressionLeaf as L, \
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
from ..possibilities import Possibility as P, MESSAGES
from .utils import nary_node
from ..translate import _
......@@ -19,7 +18,7 @@ def match_add_exponents(node):
p = []
powers = {}
for n in node.get_scope():
for n in Scope(node):
if n.is_identifier():
s = n
exponent = L(1)
......@@ -52,15 +51,15 @@ def add_exponents(root, args):
a^p * a^q -> a^(p + q)
"""
n0, n1, a, p, q = args
scope = root.get_scope()
scope = Scope(root)
# Replace the left node with the new expression
scope[scope.index(n0)] = a ** (p + q)
scope.remove(n0, a ** (p + q))
# Remove the right node
scope.remove(n1)
return nary_node('*', scope)
return scope.as_nary_node()
MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
......@@ -116,7 +115,7 @@ def match_duplicate_exponent(node):
left, right = node
if left.is_op(OP_MUL):
return [P(node, duplicate_exponent, (left.get_scope(), right))]
return [P(node, duplicate_exponent, (list(Scope(left)), right))]
return []
......@@ -159,7 +158,7 @@ def match_extend_exponent(node):
left, right = node
if right.is_numeric():
for n in node.get_scope():
for n in Scope(node):
if n.is_op(OP_ADD):
return [P(node, extend_exponent, (left, right))]
......
from ..node import ExpressionNode as Node
def nary_node(operator, scope):
"""
Create a binary expression tree for an n-ary operator. Takes the operator
and a list of expression nodes as arguments.
"""
return scope[0] if len(scope) == 1 \
else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
def gcd(a, b):
"""
Return greatest common divisor using Euclid's Algorithm.
......
from rules.utils import nary_node
class Scope(object):
def __init__(self, node):
self.node = node
self.nodes = node.get_scope()
def remove(self, node, replacement=None):
if node.is_leaf():
node_cmp = hash(node)
else:
node_cmp = node
for i, n in enumerate(self.nodes):
if n.is_leaf():
n_cmp = hash(n)
else:
n_cmp = n
if n_cmp == node_cmp:
if replacement != None:
self.nodes[i] = replacement
else:
del self.nodes[i]
return
raise ValueError('Node "%s" is not in the scope of "%s".'
% (node, self.node))
def as_nary_node(self):
return nary_node(self.node.value, self.nodes)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
nary_node, get_scope, OP_ADD
from tests.rulestestcase import RulesTestCase, tree
from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
from tests.rulestestcase import tree
class TestNode(unittest.TestCase):
class TestNode(RulesTestCase):
def setUp(self):
self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___lt__(self):
self.assertTrue(L(1) < L(2))
......@@ -95,19 +98,19 @@ class TestNode(unittest.TestCase):
def test_get_scope_binary(self):
plus = N('+', *self.l[:2])
self.assertEqual(plus.get_scope(), self.l[:2])
self.assertEqual(get_scope(plus), self.l[:2])
def test_get_scope_nested_left(self):
plus = N('+', N('+', *self.l[:2]), self.l[2])
self.assertEqual(plus.get_scope(), self.l[:3])
self.assertEqual(get_scope(plus), self.l[:3])
def test_get_scope_nested_right(self):
plus = N('+', self.l[0], N('+', *self.l[1:3]))
self.assertEqual(plus.get_scope(), self.l[:3])
self.assertEqual(get_scope(plus), self.l[:3])
def test_get_scope_nested_deep(self):
plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
self.assertEqual(plus.get_scope(), self.l)
self.assertEqual(get_scope(plus), self.l)
def test_equals_node_leaf(self):
a, b = plus = tree('a + b')
......@@ -168,3 +171,36 @@ class TestNode(unittest.TestCase):
m0, m1 = tree('-5 * -3,-5 * 6')
self.assertFalse(m0.equals(m1))
def test_scope___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_scope_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_scope_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_scope_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_scope_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_nary_node(self):
a, b, c, d = tree('a,b,c,d')
self.assertEqualNodes(nary_node('+', [a]), a)
self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
self.assertEqualNodes(nary_node('+', [a, b, c]),
N('+', N('+', a, b), c))
self.assertEqualNodes(nary_node('+', [a, b, c, d]),
N('+', N('+', N('+', a, b), c), d))
def test_scope_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
import unittest
from src.node import ExpressionNode as N, ExpressionLeaf as L
from src.rules.utils import nary_node
class TestRules(unittest.TestCase):
def test_nary_node_binary(self):
l0, l1 = L(1), L(2)
plus = N('+', l0, l1)
self.assertEqual(nary_node('+', [l0, l1]), plus)
def test_nary_node_ternary(self):
l0, l1, l2 = L(1), L(2), L(3)
plus = N('+', N('+', l0, l1), l2)
self.assertEqual(nary_node('+', [l0, l1, l2]), plus)
from src.node import ExpressionNode as N
from src.rules.utils import nary_node, least_common_multiple
from tests.rulestestcase import RulesTestCase, tree
import unittest
from src.rules.utils import least_common_multiple
class TestRulesUtils(RulesTestCase):
def test_nary_node(self):
a, b, c, d = tree('a,b,c,d')
self.assertEqualNodes(nary_node('+', [a]), a)
self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
self.assertEqualNodes(nary_node('+', [a, b, c]),
N('+', N('+', a, b), c))
self.assertEqualNodes(nary_node('+', [a, b, c, d]),
N('+', N('+', N('+', a, b), c), d))
class TestRulesUtils(unittest.TestCase):
def test_least_common_multiple(self):
self.assertEqual(least_common_multiple(5, 6), 30)
......
import unittest
from src.scope import Scope
from tests.rulestestcase import RulesTestCase, tree
class TestScope(RulesTestCase):
def setUp(self):
self.n, self.f = tree('a + b + cd,f')
(self.a, self.b), self.cd = self.n
self.c, self.d = self.cd
self.scope = Scope(self.n)
def test___init__(self):
self.assertEqual(self.scope.node, self.n)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
def test_remove_leaf(self):
self.scope.remove(self.b)
self.assertEqual(self.scope.nodes, [self.a, self.cd])
def test_remove_node(self):
self.scope.remove(self.cd)
self.assertEqual(self.scope.nodes, [self.a, self.b])
def test_remove_replace(self):
self.scope.remove(self.cd, self.f)
self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
def test_remove_error(self):
with self.assertRaises(ValueError):
self.scope.remove(self.f)
def test_as_nary_node(self):
self.assertEqualNodes(self.scope.as_nary_node(), self.n)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment