Forráskód Böngészése

Fixed merge conflict.

Sander Mathijs van Veen 14 éve
szülő
commit
d1eb149c2d

+ 73 - 19
src/node.py

@@ -260,23 +260,6 @@ class ExpressionNode(Node, ExpressionBase):
             return (self[0], self[1], ExpressionLeaf(1))
         return (self[1], self[0], ExpressionLeaf(1))
 
-    def get_scope(self):
-        """
-        Find all n nodes within the n-ary scope of this operator.
-        """
-        scope = []
-        #op = OP_ADD | OP_SUB if self.op & (OP_ADD | OP_SUB) else self.op
-
-        # TODO: what to do with OP_SUB and OP_ADD in get_scope?
-
-        for child in self:
-            if not child.is_leaf() and child.op == self.op:
-                scope += child.get_scope()
-            else:
-                scope.append(child)
-
-        return scope
-
     def equals(self, other):
         """
         Perform a non-strict equivalence check between two nodes:
@@ -292,8 +275,8 @@ class ExpressionNode(Node, ExpressionBase):
             return False
 
         if self.op in (OP_ADD, OP_MUL):
-            s0 = self.get_scope()
-            s1 = set(other.get_scope())
+            s0 = Scope(self)
+            s1 = set(Scope(other))
 
             # Scopes sould be of equal size
             if len(s0) != len(s1):
@@ -354,3 +337,74 @@ class ExpressionLeaf(Leaf, ExpressionBase):
         """
         # rule: 1 * r ^ 1 -> (1, r, 1)
         return (ExpressionLeaf(1), self, ExpressionLeaf(1))
+
+
+class Scope(object):
+
+    def __init__(self, node):
+        self.node = node
+        self.nodes = get_scope(node)
+
+    def __getitem__(self, key):
+        return self.nodes[key]
+
+    def __setitem__(self, key, value):
+        self.nodes[key] = value
+
+    def __len__(self):
+        return len(self.nodes)
+
+    def __iter__(self):
+        return iter(self.nodes)
+
+    def remove(self, node, replacement=None):
+        if node.is_leaf():
+            node_cmp = hash(node)
+        else:
+            node_cmp = node
+
+        for i, n in enumerate(self.nodes):
+            if n.is_leaf():
+                n_cmp = hash(n)
+            else:
+                n_cmp = n
+
+            if n_cmp == node_cmp:
+                if replacement != None:
+                    self[i] = replacement
+                else:
+                    del self.nodes[i]
+
+                return
+
+        raise ValueError('Node "%s" is not in the scope of "%s".'
+                         % (node, self.node))
+
+    def as_nary_node(self):
+        return nary_node(self.node.value, self.nodes)
+
+
+def nary_node(operator, scope):
+    """
+    Create a binary expression tree for an n-ary operator. Takes the operator
+    and a list of expression nodes as arguments.
+    """
+    if len(scope) == 1:
+        return scope[0]
+
+    return ExpressionNode(operator, nary_node(operator, scope[:-1]), scope[-1])
+
+
+def get_scope(node):
+    """
+    Find all n nodes within the n-ary scope of an operator node.
+    """
+    scope = []
+
+    for child in node:
+        if child.is_op(node.op):
+            scope += get_scope(child)
+        else:
+            scope.append(child)
+
+    return scope

+ 9 - 10
src/rules/factors.py

@@ -1,7 +1,6 @@
 from itertools import product, combinations
 
-from .utils import nary_node
-from ..node import OP_ADD, OP_MUL, OP_NEG
+from ..node import Scope, OP_ADD, OP_MUL, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -18,7 +17,7 @@ def match_expand(node):
     leaves = []
     additions = []
 
-    for n in node.get_scope():
+    for n in Scope(node):
         if n.is_leaf() or n.is_op(OP_NEG) and n[0].is_leaf():
             leaves.append(n)
         elif n.op == OP_ADD:
@@ -43,15 +42,15 @@ def expand_single(root, args):
     """
     a, bc = args
     b, c = bc
-    scope = root.get_scope()
+    scope = Scope(root)
 
     # Replace 'a' with the new expression
-    scope[scope.index(a)] = a * b + a * c
+    scope.remove(a, a * b + a * c)
 
     # Remove the addition
     scope.remove(bc)
 
-    return nary_node('*', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[expand_single] = _('Expand {1}({2}) to {1}({2[0]}) + {1}({2[1]}).')
@@ -64,15 +63,15 @@ def expand_double(root, args):
     (a + b) * (c + d) -> ac + ad + bc + bd
     """
     (a, b), (c, d) = ab, cd = args
-    scope = root.get_scope()
+    scope = Scope(root)
 
-    # Replace 'b + c' with the new expression
-    scope[scope.index(ab)] = a * c + a * d + b * c + b * d
+    # Replace 'a + b' with the new expression
+    scope.remove(ab, a * c + a * d + b * c + b * d)
 
     # Remove the right addition
     scope.remove(cd)
 
-    return nary_node('*', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[expand_double] = _('Expand ({1})({2}) to {1[0]}{2[0]} + {1[0]}{2[1]}'

+ 10 - 12
src/rules/fractions.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
-from .utils import nary_node, least_common_multiple
-from ..node import ExpressionLeaf as L, OP_DIV, OP_ADD, OP_MUL, OP_NEG
+from .utils import least_common_multiple
+from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -84,7 +84,7 @@ def match_add_constant_fractions(node):
         return node.is_op(OP_DIV) or \
                 (node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
 
-    fractions = filter(is_division, node.get_scope())
+    fractions = filter(is_division, Scope(node))
 
     for a, b in combinations(fractions, 2):
         if a.is_op(OP_NEG):
@@ -117,7 +117,7 @@ def equalize_denominators(root, args):
     """
     denom = args[2]
 
-    scope = root.get_scope()
+    scope = Scope(root)
 
     for fraction in args[:2]:
         n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction
@@ -127,11 +127,11 @@ def equalize_denominators(root, args):
             n = L(n.value * mult) if n.is_numeric() else L(mult) * n
 
             if fraction.is_op(OP_NEG):
-                scope[scope.index(fraction)] = -(n / L(d.value * mult))
+                scope.remove(fraction, -(n / L(d.value * mult)))
             else:
-                scope[scope.index(fraction)] = n / L(d.value * mult)
+                scope.remove(fraction, n / L(d.value * mult))
 
-    return nary_node('+', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[equalize_denominators] = _('Equalize the denominators of division'
@@ -157,17 +157,15 @@ def add_nominators(root, args):
     else:
         c = cb[0]
 
-    substitution = (a + c) / b
-
-    scope = root.get_scope()
+    scope = Scope(root)
 
     # Replace the left node with the new expression
-    scope[scope.index(ab)] = substitution
+    scope.remove(ab, (a + c) / b)
 
     # Remove the right node
     scope.remove(cb)
 
-    return nary_node('+', scope)
+    return scope.as_nary_node()
 
 
 # TODO: convert this to a lambda. Example: 22 / 77 - 28 / 77. the "-" is above

+ 7 - 8
src/rules/groups.py

@@ -1,10 +1,9 @@
 from itertools import combinations
 
-from ..node import OP_ADD, OP_MUL, ExpressionNode as Node, \
-        ExpressionLeaf as Leaf
+from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
+        OP_ADD, OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
-from .utils import nary_node
 
 
 def match_combine_groups(node):
@@ -25,13 +24,13 @@ def match_combine_groups(node):
     p = []
     groups = []
 
-    for n in node.get_scope():
+    for n in Scope(node):
         groups.append((1, n, n))
 
         # Each number multiplication yields a group, multiple occurences of
         # the same group can be replaced by a single one
         if n.is_op(OP_MUL):
-            scope = n.get_scope()
+            scope = Scope(n)
             l = len(scope)
 
             for i, sub_node in enumerate(scope):
@@ -55,18 +54,18 @@ def match_combine_groups(node):
 def combine_groups(root, args):
     c0, g0, n0, c1, g1, n1 = args
 
-    scope = root.get_scope()
+    scope = Scope(root)
 
     if not isinstance(c0, Leaf):
         c0 = Leaf(c0)
 
     # Replace the left node with the new expression
-    scope[scope.index(n0)] = (c0 + c1) * g0
+    scope.remove(n0, (c0 + c1) * g0)
 
     # Remove the right node
     scope.remove(n1)
 
-    return nary_node('+', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[combine_groups] = \

+ 7 - 7
src/rules/numerics.py

@@ -1,7 +1,7 @@
 from itertools import combinations
 
-from .utils import nary_node
-from ..node import ExpressionLeaf as Leaf, OP_DIV, OP_MUL, OP_NEG
+from ..node import ExpressionLeaf as Leaf, Scope, nary_node, OP_DIV, OP_MUL, \
+        OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -28,15 +28,15 @@ def add_numerics(root, args):
     else:
         c1 = c1.value
 
-    scope = root.get_scope()
+    scope = Scope(root)
 
     # Replace the left node with the new expression
-    scope[scope.index(n0)] = Leaf(c0 + c1)
+    scope.remove(n0, Leaf(c0 + c1))
 
     # Remove the right node
     scope.remove(n1)
 
-    return nary_node('+', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[add_numerics] = _('Combine the constants {1} and {2}, which'
@@ -119,7 +119,7 @@ def match_multiply_numerics(node):
     p = []
     numerics = []
 
-    for n in node.get_scope():
+    for n in Scope(node):
         if n.is_numeric():
             numerics.append((n, n.value))
         elif n.is_op(OP_NEG) and n[0].is_numeric():
@@ -147,7 +147,7 @@ def multiply_numerics(root, args):
     else:
         substitution = -Leaf(-value)
 
-    for n in root.get_scope():
+    for n in Scope(root):
         if hash(n) == hash(n0):
             # Replace the left node with the new expression
             scope.append(substitution)

+ 8 - 11
src/rules/poly.py

@@ -1,8 +1,7 @@
 from itertools import combinations
 
-from ..node import OP_ADD, OP_NEG
+from ..node import Scope, OP_ADD, OP_NEG
 from ..possibilities import Possibility as P, MESSAGES
-from .utils import nary_node
 from .numerics import add_numerics
 
 
@@ -32,7 +31,7 @@ def match_combine_polynomes(node, verbose=False):
     if verbose:  # pragma: nocover
         print 'match combine factors:', node
 
-    for n in node.get_scope():
+    for n in Scope(node):
         polynome = n.extract_polynome_properties()
 
         if verbose:  # pragma: nocover
@@ -84,16 +83,14 @@ def combine_polynomes(root, args):
     else:
         power = r ** e
 
-    # replacement: (c0 + c1) * a ^ b
-    # a, b and c are from 'left', d is from 'right'.
-    replacement = (c0 + c1) * power
-
-    scope = root.get_scope()
+    scope = Scope(root)
 
-    # Replace the left node with the new expression
-    scope[scope.index(n0)] = replacement
+    # Replace the left node with the new expression:
+    # (c0 + c1) * a ^ b
+    # a, b and c are from 'left', d is from 'right'.
+    scope.remove(n0, (c0 + c1) * power)
 
     # Remove the right node
     scope.remove(n1)
 
-    return nary_node('+', scope)
+    return scope.as_nary_node()

+ 7 - 8
src/rules/powers.py

@@ -1,9 +1,8 @@
 from itertools import combinations
 
-from ..node import ExpressionNode as N, ExpressionLeaf as L, \
+from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
                    OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD
 from ..possibilities import Possibility as P, MESSAGES
-from .utils import nary_node
 from ..translate import _
 
 
@@ -19,7 +18,7 @@ def match_add_exponents(node):
     p = []
     powers = {}
 
-    for n in node.get_scope():
+    for n in Scope(node):
         if n.is_identifier():
             s = n
             exponent = L(1)
@@ -52,15 +51,15 @@ def add_exponents(root, args):
     a^p * a^q  ->  a^(p + q)
     """
     n0, n1, a, p, q = args
-    scope = root.get_scope()
+    scope = Scope(root)
 
     # Replace the left node with the new expression
-    scope[scope.index(n0)] = a ** (p + q)
+    scope.remove(n0, a ** (p + q))
 
     # Remove the right node
     scope.remove(n1)
 
-    return nary_node('*', scope)
+    return scope.as_nary_node()
 
 
 MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}, which'
@@ -116,7 +115,7 @@ def match_duplicate_exponent(node):
     left, right = node
 
     if left.is_op(OP_MUL):
-        return [P(node, duplicate_exponent, (left.get_scope(), right))]
+        return [P(node, duplicate_exponent, (list(Scope(left)), right))]
 
     return []
 
@@ -159,7 +158,7 @@ def match_extend_exponent(node):
     left, right = node
 
     if right.is_numeric():
-        for n in node.get_scope():
+        for n in Scope(node):
             if n.is_op(OP_ADD):
                 return [P(node, extend_exponent, (left, right))]
 

+ 0 - 12
src/rules/utils.py

@@ -1,15 +1,3 @@
-from ..node import ExpressionNode as Node
-
-
-def nary_node(operator, scope):
-    """
-    Create a binary expression tree for an n-ary operator. Takes the operator
-    and a list of expression nodes as arguments.
-    """
-    return scope[0] if len(scope) == 1 \
-           else Node(operator, nary_node(operator, scope[:-1]), scope[-1])
-
-
 def gcd(a, b):
     """
     Return greatest common divisor using Euclid's Algorithm.

+ 0 - 34
src/scope.py

@@ -1,34 +0,0 @@
-from rules.utils import nary_node
-
-
-class Scope(object):
-
-    def __init__(self, node):
-        self.node = node
-        self.nodes = node.get_scope()
-
-    def remove(self, node, replacement=None):
-        if node.is_leaf():
-            node_cmp = hash(node)
-        else:
-            node_cmp = node
-
-        for i, n in enumerate(self.nodes):
-            if n.is_leaf():
-                n_cmp = hash(n)
-            else:
-                n_cmp = n
-
-            if n_cmp == node_cmp:
-                if replacement != None:
-                    self.nodes[i] = replacement
-                else:
-                    del self.nodes[i]
-
-                return
-
-        raise ValueError('Node "%s" is not in the scope of "%s".'
-                         % (node, self.node))
-
-    def as_nary_node(self):
-        return nary_node(self.node.value, self.nodes)

+ 45 - 9
tests/test_node.py

@@ -1,13 +1,16 @@
-import unittest
+from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
+        nary_node, get_scope, OP_ADD
+from tests.rulestestcase import RulesTestCase, tree
 
-from src.node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD
-from tests.rulestestcase import tree
 
-
-class TestNode(unittest.TestCase):
+class TestNode(RulesTestCase):
 
     def setUp(self):
         self.l = [L(1), N('*', L(2), L(3)), L(4), L(5)]
+        self.n, self.f = tree('a + b + cd,f')
+        (self.a, self.b), self.cd = self.n
+        self.c, self.d = self.cd
+        self.scope = Scope(self.n)
 
     def test___lt__(self):
         self.assertTrue(L(1) < L(2))
@@ -95,19 +98,19 @@ class TestNode(unittest.TestCase):
 
     def test_get_scope_binary(self):
         plus = N('+', *self.l[:2])
-        self.assertEqual(plus.get_scope(), self.l[:2])
+        self.assertEqual(get_scope(plus), self.l[:2])
 
     def test_get_scope_nested_left(self):
         plus = N('+', N('+', *self.l[:2]), self.l[2])
-        self.assertEqual(plus.get_scope(), self.l[:3])
+        self.assertEqual(get_scope(plus), self.l[:3])
 
     def test_get_scope_nested_right(self):
         plus = N('+', self.l[0], N('+', *self.l[1:3]))
-        self.assertEqual(plus.get_scope(), self.l[:3])
+        self.assertEqual(get_scope(plus), self.l[:3])
 
     def test_get_scope_nested_deep(self):
         plus = N('+', N('+', N('+', *self.l[:2]), self.l[2]), self.l[3])
-        self.assertEqual(plus.get_scope(), self.l)
+        self.assertEqual(get_scope(plus), self.l)
 
     def test_equals_node_leaf(self):
         a, b = plus = tree('a + b')
@@ -168,3 +171,36 @@ class TestNode(unittest.TestCase):
 
         m0, m1 = tree('-5 * -3,-5 * 6')
         self.assertFalse(m0.equals(m1))
+
+    def test_scope___init__(self):
+        self.assertEqual(self.scope.node, self.n)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
+
+    def test_scope_remove_leaf(self):
+        self.scope.remove(self.b)
+        self.assertEqual(self.scope.nodes, [self.a, self.cd])
+
+    def test_scope_remove_node(self):
+        self.scope.remove(self.cd)
+        self.assertEqual(self.scope.nodes, [self.a, self.b])
+
+    def test_scope_remove_replace(self):
+        self.scope.remove(self.cd, self.f)
+        self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
+
+    def test_scope_remove_error(self):
+        with self.assertRaises(ValueError):
+            self.scope.remove(self.f)
+
+    def test_nary_node(self):
+        a, b, c, d = tree('a,b,c,d')
+
+        self.assertEqualNodes(nary_node('+', [a]), a)
+        self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
+        self.assertEqualNodes(nary_node('+', [a, b, c]),
+                              N('+', N('+', a, b), c))
+        self.assertEqualNodes(nary_node('+', [a, b, c, d]),
+                              N('+', N('+', N('+', a, b), c), d))
+
+    def test_scope_as_nary_node(self):
+        self.assertEqualNodes(self.scope.as_nary_node(), self.n)

+ 0 - 17
tests/test_rules.py

@@ -1,17 +0,0 @@
-import unittest
-
-from src.node import ExpressionNode as N, ExpressionLeaf as L
-from src.rules.utils import nary_node
-
-
-class TestRules(unittest.TestCase):
-
-    def test_nary_node_binary(self):
-        l0, l1 = L(1), L(2)
-        plus = N('+', l0, l1)
-        self.assertEqual(nary_node('+', [l0, l1]), plus)
-
-    def test_nary_node_ternary(self):
-        l0, l1, l2 = L(1), L(2), L(3)
-        plus = N('+', N('+', l0, l1), l2)
-        self.assertEqual(nary_node('+', [l0, l1, l2]), plus)

+ 3 - 13
tests/test_rules_utils.py

@@ -1,19 +1,9 @@
-from src.node import ExpressionNode as N
-from src.rules.utils import nary_node, least_common_multiple
-from tests.rulestestcase import RulesTestCase, tree
+import unittest
 
+from src.rules.utils import least_common_multiple
 
-class TestRulesUtils(RulesTestCase):
 
-    def test_nary_node(self):
-        a, b, c, d = tree('a,b,c,d')
-
-        self.assertEqualNodes(nary_node('+', [a]), a)
-        self.assertEqualNodes(nary_node('+', [a, b]), N('+', a, b))
-        self.assertEqualNodes(nary_node('+', [a, b, c]),
-                              N('+', N('+', a, b), c))
-        self.assertEqualNodes(nary_node('+', [a, b, c, d]),
-                              N('+', N('+', N('+', a, b), c), d))
+class TestRulesUtils(unittest.TestCase):
 
     def test_least_common_multiple(self):
         self.assertEqual(least_common_multiple(5, 6), 30)

+ 0 - 36
tests/test_scope.py

@@ -1,36 +0,0 @@
-import unittest
-
-from src.scope import Scope
-from tests.rulestestcase import RulesTestCase, tree
-
-
-class TestScope(RulesTestCase):
-
-    def setUp(self):
-        self.n, self.f = tree('a + b + cd,f')
-        (self.a, self.b), self.cd = self.n
-        self.c, self.d = self.cd
-        self.scope = Scope(self.n)
-
-    def test___init__(self):
-        self.assertEqual(self.scope.node, self.n)
-        self.assertEqual(self.scope.nodes, [self.a, self.b, self.cd])
-
-    def test_remove_leaf(self):
-        self.scope.remove(self.b)
-        self.assertEqual(self.scope.nodes, [self.a, self.cd])
-
-    def test_remove_node(self):
-        self.scope.remove(self.cd)
-        self.assertEqual(self.scope.nodes, [self.a, self.b])
-
-    def test_remove_replace(self):
-        self.scope.remove(self.cd, self.f)
-        self.assertEqual(self.scope.nodes, [self.a, self.b, self.f])
-
-    def test_remove_error(self):
-        with self.assertRaises(ValueError):
-            self.scope.remove(self.f)
-
-    def test_as_nary_node(self):
-        self.assertEqualNodes(self.scope.as_nary_node(), self.n)