Forráskód Böngészése

Added some tests for chein rule.

Taddeus Kroes 14 éve
szülő
commit
4ffcd6599a
2 módosított fájl, 21 hozzáadás és 11 törlés
  1. 4 4
      src/rules/derivatives.py
  2. 17 7
      tests/test_rules_derivatives.py

+ 4 - 4
src/rules/derivatives.py

@@ -115,11 +115,12 @@ def match_const_deriv_multiplication(node):
     p = []
 
     if node[0].is_op(OP_MUL):
+        x = L(get_derivation_variable(node))
         scope = Scope(node[0])
 
         for n in scope:
-            if n.is_numeric():
-                p.append(P(node, const_deriv_multiplication, (scope, n)))
+            if not n.contains(x):
+                p.append(P(node, const_deriv_multiplication, (scope, n, x)))
 
     return p
 
@@ -128,10 +129,9 @@ def const_deriv_multiplication(root, args):
     """
     der(c * f(x), x)  ->  c * der(f(x), x)
     """
-    scope, c = args
+    scope, c, x = args
 
     scope.remove(c)
-    x = L(get_derivation_variable(root))
 
     return c * der(scope.as_nary_node(), x)
 

+ 17 - 7
tests/test_rules_derivatives.py

@@ -49,20 +49,24 @@ class TestRulesDerivatives(RulesTestCase):
         root = tree('der(2x)')
         l2, x = root[0]
         self.assertEqualPos(match_const_deriv_multiplication(root),
-                [P(root, const_deriv_multiplication, (Scope(root[0]), l2))])
+                [P(root, const_deriv_multiplication, (Scope(root[0]), l2, x))])
+
+        (x, y), x = root = tree('der(xy, x)')
+        self.assertEqualPos(match_const_deriv_multiplication(root),
+                [P(root, const_deriv_multiplication, (Scope(root[0]), y, x))])
 
     def test_match_const_deriv_multiplication_multiple_constants(self):
         root = tree('der(2x * 3)')
         (l2, x), l3 = root[0]
         scope = Scope(root[0])
         self.assertEqualPos(match_const_deriv_multiplication(root),
-                [P(root, const_deriv_multiplication, (scope, l2)),
-                 P(root, const_deriv_multiplication, (scope, l3))])
+                [P(root, const_deriv_multiplication, (scope, l2, x)),
+                 P(root, const_deriv_multiplication, (scope, l3, x))])
 
     def test_const_deriv_multiplication(self):
         root = tree('der(2x)')
         l2, x = root[0]
-        args = Scope(root[0]), l2
+        args = Scope(root[0]), l2, x
         self.assertEqual(const_deriv_multiplication(root, args),
                          l2 * der(x, x))
 
@@ -94,8 +98,14 @@ class TestRulesDerivatives(RulesTestCase):
         x, n = root[0]
         self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
 
-    def test_variable_root_chain_rule(self):
-        pass
+    def test_variable_exponent(self):
+        root = tree('der(2 ^ x)')
+        g, x = root[0]
+        self.assertEqual(variable_exponent(root, ()), g ** x * ln(g))
 
     def test_chain_rule(self):
-        pass
+        root = tree('der(2 ^ x ^ 3)')
+        l2, x3 = root[0]
+        x, l3 = x3
+        self.assertEqual(chain_rule(root, (x3, variable_exponent, ())),
+                          l2 ** x3 * ln(l2) * der(x3))