فهرست منبع

All derivative rules now keep using the same derivative operator al the way through the calclation.

Taddeus Kroes 13 سال پیش
والد
کامیت
c3a2ee8a5a
3فایلهای تغییر یافته به همراه18 افزوده شده و 14 حذف شده
  1. 15 13
      src/rules/derivatives.py
  2. 1 1
      src/rules/groups.py
  3. 2 0
      tests/test_rules_derivatives.py

+ 15 - 13
src/rules/derivatives.py

@@ -26,6 +26,14 @@ def second_arg(node):
     return node[1] if len(node) > 1 else None
 
 
+def same_der(root, new):
+    """
+    Replace root with a new derrivative, using the same operator as root
+    (OP_PRIME or DXDER).
+    """
+    return der(new, second_arg(root))
+
+
 def get_derivation_variable(node, variables=None):
     """
     Find the variable to derive over.
@@ -56,9 +64,8 @@ def chain_rule(root, args):
     as der(g(x), x).
     """
     g, f_deriv, f_deriv_args = args
-    x = root[1] if len(root) > 1 else None
 
-    return f_deriv(root, f_deriv_args) * der(g, x)
+    return f_deriv(root, f_deriv_args) * same_der(root, g)
 
 
 MESSAGES[chain_rule] = _('Apply the chain rule to {0}.')
@@ -144,7 +151,7 @@ def const_deriv_multiplication(root, args):
 
     scope.remove(c)
 
-    return c * der(scope.as_nary_node(), x)
+    return c * same_der(root, scope.as_nary_node())
 
 
 MESSAGES[const_deriv_multiplication] = \
@@ -188,7 +195,7 @@ def power_rule(root, args):
     """
     [f(x) ^ g(x)]'  ->  [e ^ ln(f(x) ^ g(x))]'
     """
-    return der(L(E) ** ln(root[0]), second_arg(root))
+    return same_der(root, L(E) ** ln(root[0]))
 
 
 MESSAGES[power_rule] = \
@@ -326,7 +333,7 @@ def tangens(root, args):
     """
     x = root[0][0]
 
-    return der(sin(x) / cos(x), second_arg(root))
+    return same_der(root, sin(x) / cos(x))
 
 
 MESSAGES[tangens] = \
@@ -365,11 +372,9 @@ def sum_rule(root, args):
     [f(x) + g(x)]'  ->  f'(x) + g'(x)
     """
     scope, f = args
-    x = second_arg(root)
-
     scope.remove(f)
 
-    return der(f, x) + der(scope.as_nary_node(), x)
+    return same_der(root, f) + same_der(root, scope.as_nary_node())
 
 
 MESSAGES[sum_rule] = _('Apply the sum rule to {0}.')
@@ -383,12 +388,10 @@ def product_rule(root, args):
     [f(x) * g(x) * h(x)]'  ->  f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]'
     """
     scope, f = args
-    x = second_arg(root)
-
     scope.remove(f)
     gh = scope.as_nary_node()
 
-    return der(f, x) * gh + f * der(gh, x)
+    return same_der(root, f) * gh + f * same_der(root, gh)
 
 
 MESSAGES[product_rule] = _('Apply the product rule to {0}.')
@@ -419,9 +422,8 @@ def quotient_rule(root, args):
     [f(x) / g(x)]'  ->  (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
     """
     f, g = root[0]
-    x = second_arg(root)
 
-    return (der(f, x) * g - f * der(g, x)) / g ** 2
+    return (same_der(root, f) * g - f * same_der(root, g)) / g ** 2
 
 
 MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.')

+ 1 - 1
src/rules/groups.py

@@ -43,7 +43,7 @@ def match_combine_groups(node):
         if not n.is_numeric():
             groups.append((Leaf(1), n, n, True))
 
-        # Each number multiplication yields a group, multiple occurences of
+        # Each number multiplication yields a group, multiple occurrences of
         # the same group can be replaced by a single one
         if n.is_op(OP_MUL):
             n_scope = Scope(n)

+ 2 - 0
tests/test_rules_derivatives.py

@@ -85,6 +85,8 @@ class TestRulesDerivatives(RulesTestCase):
         self.assertEqual(const_deriv_multiplication(root, args),
                          l2 * der(x, x))
 
+        self.assertRewrite(["[2x]'", "2[x]'", '2 * 1', '2'])
+
     def test_match_variable_power(self):
         root, x, l2 = tree('d/dx x ^ 2, x, 2')
         self.assertEqualPos(match_variable_power(root),