瀏覽代碼

Added negation checks to sinus + cosinus quadrants rule.

Taddeus Kroes 14 年之前
父節點
當前提交
9396116f82
共有 2 個文件被更改,包括 47 次插入13 次删除
  1. 18 8
      src/rules/goniometry.py
  2. 29 5
      tests/test_rules_goniometry.py

+ 18 - 8
src/rules/goniometry.py

@@ -1,5 +1,8 @@
+from itertools import permutations
+
 from ..node import ExpressionNode as N, ExpressionLeaf as L, OP_ADD, OP_MUL, \
-        OP_DIV, OP_SIN, OP_COS, OP_TAN, OP_SQRT, PI, TYPE_OPERATOR, sin, cos
+        OP_DIV, OP_SIN, OP_COS, OP_TAN, OP_SQRT, PI, TYPE_OPERATOR, sin, cos, \
+        Scope
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -11,13 +14,16 @@ def match_add_quadrants(node):
     assert node.is_op(OP_ADD)
 
     p = []
-    sin_q, cos_q = node
+    scope = Scope(node)
 
-    if sin_q.is_power(2) and cos_q.is_power(2):
-        sinus, cosinus = sin_q[0], cos_q[0]
+    for sin_q, cos_q in permutations(scope, 2):
+        if sin_q.is_power(2) and cos_q.is_power(2) \
+                and not sin_q.negated and not cos_q.negated:
+            s, c = sin_q[0], cos_q[0]
 
-        if sinus.is_op(OP_SIN) and cosinus.is_op(OP_COS):
-            p.append(P(node, add_quadrants, ()))
+            if s.is_op(OP_SIN) and c.is_op(OP_COS) and not s.negated \
+                    and not c.negated and s[0] == c[0]:
+                p.append(P(node, add_quadrants, (scope, sin_q, cos_q)))
 
     return p
 
@@ -26,7 +32,11 @@ def add_quadrants(root, args):
     """
     sin(t) ^ 2 + cos(t) ^ 2  ->  1
     """
-    return L(1)
+    scope, s, c = args
+    scope.replace(s, L(1))
+    scope.remove(c)
+
+    return scope.as_nary_node()
 
 
 MESSAGES[add_quadrants] = _('Add the sinus and cosinus quadrants to 1.')
@@ -130,7 +140,7 @@ CONSTANTS = {
 
 def match_standard_radian(node):
     """
-    Apply a direct constant calculation from the constants table.
+    Apply a direct constant calculation from the constants table:
 
         | 0 | pi / 6    | pi / 4    | pi / 3    | pi / 2
     ----+---+-----------+-----------+-----------+-------

+ 29 - 5
tests/test_rules_goniometry.py

@@ -2,7 +2,7 @@
 from src.rules.goniometry import match_add_quadrants, add_quadrants, \
         match_negated_parameter, negated_sinus_parameter, is_pi_frac, \
         negated_cosinus_parameter, match_standard_radian, standard_radian
-from src.node import PI, OP_SIN, OP_COS, OP_TAN, sin, cos, tan
+from src.node import PI, OP_SIN, OP_COS, OP_TAN, sin, cos, tan, Scope
 from src.possibilities import Possibility as P
 from tests.rulestestcase import RulesTestCase, tree
 from src.rules import goniometry
@@ -15,12 +15,36 @@ class TestRulesGoniometry(RulesTestCase):
         self.assertEqual(doctest.testmod(m=goniometry)[0], 0)
 
     def test_match_add_quadrants(self):
-        root = tree('sin(t) ^ 2 + cos(t) ^ 2')
-        possibilities = match_add_quadrants(root)
-        self.assertEqualPos(possibilities, [P(root, add_quadrants, ())])
+        s, c = root = tree('sin(t) ^ 2 + cos(t) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root),
+                [P(root, add_quadrants, (Scope(root), s, c))])
+
+        c, s = root = tree('cos(t) ^ 2 + sin(t) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root),
+                [P(root, add_quadrants, (Scope(root), s, c))])
+
+        (s, a), c = root = tree('sin(t) ^ 2 + a + cos(t) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root),
+                [P(root, add_quadrants, (Scope(root), s, c))])
+
+        (s, c0), c1 = root = tree('sin(t) ^ 2 + cos(t) ^ 2 + cos(t) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root),
+                [P(root, add_quadrants, (Scope(root), s, c0)),
+                 P(root, add_quadrants, (Scope(root), s, c1))])
+
+        root = tree('sin(t) ^ 2 + cos(y) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root), [])
+
+        root = tree('sin(t) ^ 2 - cos(t) ^ 2')
+        self.assertEqualPos(match_add_quadrants(root), [])
 
     def test_add_quadrants(self):
-        self.assertEqual(add_quadrants(None, ()), 1)
+        s, c = root = tree('sin(t) ^ 2 + cos(t) ^ 2')
+        self.assertEqual(add_quadrants(root, (Scope(root), s, c)), 1)
+
+        root, expect = tree('cos(t) ^ 2 + a + sin(t) ^ 2, a + 1')
+        (c, a), s = root
+        self.assertEqual(add_quadrants(root, (Scope(root), s, c)), expect)
 
     def test_match_negated_parameter(self):
         s, c = tree('sin -t, cos -t')