Sfoglia il codice sorgente

All derivative rules now keep using the same derivative operator al the way through the calclation.

Taddeus Kroes 13 anni fa
parent
commit
c3a2ee8a5a
3 ha cambiato i file con 18 aggiunte e 14 eliminazioni
  1. 15 13
      src/rules/derivatives.py
  2. 1 1
      src/rules/groups.py
  3. 2 0
      tests/test_rules_derivatives.py

+ 15 - 13
src/rules/derivatives.py

@@ -26,6 +26,14 @@ def second_arg(node):
     return node[1] if len(node) > 1 else None
 
 
+def same_der(root, new):
+    """
+    Replace root with a new derrivative, using the same operator as root
+    (OP_PRIME or DXDER).
+    """
+    return der(new, second_arg(root))
+
+
 def get_derivation_variable(node, variables=None):
     """
     Find the variable to derive over.
@@ -56,9 +64,8 @@ def chain_rule(root, args):
     as der(g(x), x).
     """
     g, f_deriv, f_deriv_args = args
-    x = root[1] if len(root) > 1 else None
 
-    return f_deriv(root, f_deriv_args) * der(g, x)
+    return f_deriv(root, f_deriv_args) * same_der(root, g)
 
 
 MESSAGES[chain_rule] = _('Apply the chain rule to {0}.')
@@ -144,7 +151,7 @@ def const_deriv_multiplication(root, args):
 
     scope.remove(c)
 
-    return c * der(scope.as_nary_node(), x)
+    return c * same_der(root, scope.as_nary_node())
 
 
 MESSAGES[const_deriv_multiplication] = \
@@ -188,7 +195,7 @@ def power_rule(root, args):
     """
     [f(x) ^ g(x)]'  ->  [e ^ ln(f(x) ^ g(x))]'
     """
-    return der(L(E) ** ln(root[0]), second_arg(root))
+    return same_der(root, L(E) ** ln(root[0]))
 
 
 MESSAGES[power_rule] = \
@@ -326,7 +333,7 @@ def tangens(root, args):
     """
     x = root[0][0]
 
-    return der(sin(x) / cos(x), second_arg(root))
+    return same_der(root, sin(x) / cos(x))
 
 
 MESSAGES[tangens] = \
@@ -365,11 +372,9 @@ def sum_rule(root, args):
     [f(x) + g(x)]'  ->  f'(x) + g'(x)
     """
     scope, f = args
-    x = second_arg(root)
-
     scope.remove(f)
 
-    return der(f, x) + der(scope.as_nary_node(), x)
+    return same_der(root, f) + same_der(root, scope.as_nary_node())
 
 
 MESSAGES[sum_rule] = _('Apply the sum rule to {0}.')
@@ -383,12 +388,10 @@ def product_rule(root, args):
     [f(x) * g(x) * h(x)]'  ->  f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]'
     """
     scope, f = args
-    x = second_arg(root)
-
     scope.remove(f)
     gh = scope.as_nary_node()
 
-    return der(f, x) * gh + f * der(gh, x)
+    return same_der(root, f) * gh + f * same_der(root, gh)
 
 
 MESSAGES[product_rule] = _('Apply the product rule to {0}.')
@@ -419,9 +422,8 @@ def quotient_rule(root, args):
     [f(x) / g(x)]'  ->  (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
     """
     f, g = root[0]
-    x = second_arg(root)
 
-    return (der(f, x) * g - f * der(g, x)) / g ** 2
+    return (same_der(root, f) * g - f * same_der(root, g)) / g ** 2
 
 
 MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.')

+ 1 - 1
src/rules/groups.py

@@ -43,7 +43,7 @@ def match_combine_groups(node):
         if not n.is_numeric():
             groups.append((Leaf(1), n, n, True))
 
-        # Each number multiplication yields a group, multiple occurences of
+        # Each number multiplication yields a group, multiple occurrences of
         # the same group can be replaced by a single one
         if n.is_op(OP_MUL):
             n_scope = Scope(n)

+ 2 - 0
tests/test_rules_derivatives.py

@@ -85,6 +85,8 @@ class TestRulesDerivatives(RulesTestCase):
         self.assertEqual(const_deriv_multiplication(root, args),
                          l2 * der(x, x))
 
+        self.assertRewrite(["[2x]'", "2[x]'", '2 * 1', '2'])
+
     def test_match_variable_power(self):
         root, x, l2 = tree('d/dx x ^ 2, x, 2')
         self.assertEqualPos(match_variable_power(root),