소스 검색

Fixed scope index replacement issue in fraction rule.

Taddeus Kroes 13 년 전
부모
커밋
a77a89778e
5개의 변경된 파일35개의 추가작업 그리고 13개의 파일을 삭제
  1. 9 1
      src/node.py
  2. 11 6
      src/rules/fractions.py
  3. 5 5
      src/rules/numerics.py
  4. 4 0
      src/rules/utils.py
  5. 6 1
      tests/test_rules_utils.py

+ 9 - 1
src/node.py

@@ -592,7 +592,7 @@ class Scope(object):
                 del self.nodes[i]
 
                 # Update remaining scope indices
-                for n in self[max(i, 1):]:
+                for n in self.nodes[i:]:
                     n.scope_index -= 1
         except AttributeError:
             raise ValueError('Node "%s" is not in the scope of "%s".'
@@ -603,6 +603,14 @@ class Scope(object):
 
     def as_nary_node(self):
         return nary_node(self.node.op, self.nodes).negate(self.node.negated)
+        #return negate(nary_node(self.node.op, self.nodes), self.node.negated)
+
+    def all_except(self, node):
+        before = range(0, node.scope_index)
+        after = range(node.scope_index + 1, len(self))
+        nodes = [self[i] for i in before + after]
+
+        return nary_node(self.node.op, nodes).negate(self.node.negated)
 
 
 def nary_node(operator, scope):

+ 11 - 6
src/rules/fractions.py

@@ -222,12 +222,12 @@ def multiply_with_fraction(root, args):
     scope, ab, c = args
     a, b = ab
 
-    if scope.index(ab) - scope.index(c) < 0:
-        replacement = a * c / b
+    if scope.index(ab) < scope.index(c):
+        nominator = a * c
     else:
-        replacement = c * a / b
+        nominator = c * a
 
-    scope.replace(ab, replacement.negate(ab.negated))
+    scope.replace(ab, negate(nominator / b, ab.negated))
     scope.remove(c)
 
     return scope.as_nary_node()
@@ -246,7 +246,7 @@ def match_divide_fractions(node):
     a / (b / c)      ->  ac / b
 
     Note that:
-    a / b / (c / d)  ->*  ad / bd  # chain test!
+    a / b / (c / d)  =>  ad / bd
     """
     assert node.is_op(OP_DIV)
 
@@ -347,7 +347,12 @@ def match_extract_fraction_terms(node):
     # ac / b
     for n in ifilterfalse(evals_to_numeric, n_scope):
         a_scope = mult_scope(nominator)
-        a = remove_from_mult_scope(a_scope, n)
+
+        #a = remove_from_mult_scope(a_scope, n)
+        if len(a_scope) == 1:
+            a = L(1)
+        else:
+            a = a_scope.all_except(n)
 
         if evals_to_numeric(a / denominator):
             p.append(P(node, extract_nominator_term, (a, n)))

+ 5 - 5
src/rules/numerics.py

@@ -2,7 +2,7 @@ from itertools import combinations
 
 from .utils import greatest_common_divisor, is_numeric_node
 from ..node import ExpressionLeaf as Leaf, Scope, OP_ADD, OP_DIV, OP_MUL, \
-        OP_POW
+        OP_POW, negate
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -61,7 +61,7 @@ def add_numerics(root, args):
     value = c0.actual_value() + c1.actual_value()
 
     # Replace the left node with the new expression
-    scope.replace(c0, Leaf(abs(value)).negate(int(value < 0)))
+    scope.replace(c0, Leaf(abs(value), negated=int(value < 0)))
 
     # Remove the right node
     scope.remove(c1)
@@ -141,7 +141,7 @@ def divide_numerics(root, args):
     """
     n, d = root
 
-    return Leaf(n.value / d.value).negate(root.negated)
+    return Leaf(n.value / d.value, negated=root.negated)
 
 
 MESSAGES[divide_numerics] = _('Constant division {0} reduces to a number.')
@@ -248,7 +248,7 @@ def multiply_numerics(root, args):
     scope, c0, c1 = args
 
     # Replace the left node with the new expression
-    substitution = Leaf(c0.value * c1.value).negate(c0.negated + c1.negated)
+    substitution = Leaf(c0.value * c1.value, negated=c0.negated + c1.negated)
     scope.replace(c0, substitution)
 
     # Remove the right node
@@ -284,7 +284,7 @@ def raise_numerics(root, args):
     """
     r, e = args
 
-    return Leaf(r.value ** e.value).negate(r.negated * e.value)
+    return Leaf(r.value ** e.value, negated=r.negated * e.value)
 
 
 MESSAGES[raise_numerics] = _('Raise constant {1} with {2}.')

+ 4 - 0
src/rules/utils.py

@@ -205,3 +205,7 @@ def iter_pairs(list_iterable):
 
     for i, left in enumerate(list_iterable[:-1]):
         yield left, list_iterable[i + 1]
+
+
+def range_except(start, end, exception):
+    return range(start, exception) + range(exception + 1, end)

+ 6 - 1
tests/test_rules_utils.py

@@ -2,7 +2,7 @@ from src.rules import utils
 from src.rules.utils import least_common_multiple, is_fraction, partition, \
         find_variables, first_sorted_variable, find_variable, substitute, \
         divides, dividers, is_prime, prime_dividers, evals_to_numeric, \
-        iter_pairs
+        iter_pairs, range_except
 from tests.rulestestcase import tree, RulesTestCase
 
 
@@ -110,3 +110,8 @@ class TestRulesUtils(RulesTestCase):
         self.assertEqual(list(iter_pairs([1, 2])), [(1, 2)])
         self.assertEqual(list(iter_pairs([1, 2, 3])), [(1, 2), (2, 3)])
         self.assertEqual(list(iter_pairs([1, 2, 3, 4])), [(1, 2), (2, 3), (3, 4)])
+
+    def test_range_except(self):
+        self.assertEqual(range_except(0, 5, 2), [0, 1, 3, 4])
+        self.assertEqual(range_except(0, 4, 0), [1, 2, 3])
+        self.assertEqual(range_except(0, 3, 3), [0, 1, 2])