Commit c3a2ee8a authored by Taddeüs Kroes's avatar Taddeüs Kroes

All derivative rules now keep using the same derivative operator al the way through the calclation.

parent 4c15f528
...@@ -26,6 +26,14 @@ def second_arg(node): ...@@ -26,6 +26,14 @@ def second_arg(node):
return node[1] if len(node) > 1 else None return node[1] if len(node) > 1 else None
def same_der(root, new):
"""
Replace root with a new derrivative, using the same operator as root
(OP_PRIME or DXDER).
"""
return der(new, second_arg(root))
def get_derivation_variable(node, variables=None): def get_derivation_variable(node, variables=None):
""" """
Find the variable to derive over. Find the variable to derive over.
...@@ -56,9 +64,8 @@ def chain_rule(root, args): ...@@ -56,9 +64,8 @@ def chain_rule(root, args):
as der(g(x), x). as der(g(x), x).
""" """
g, f_deriv, f_deriv_args = args g, f_deriv, f_deriv_args = args
x = root[1] if len(root) > 1 else None
return f_deriv(root, f_deriv_args) * der(g, x) return f_deriv(root, f_deriv_args) * same_der(root, g)
MESSAGES[chain_rule] = _('Apply the chain rule to {0}.') MESSAGES[chain_rule] = _('Apply the chain rule to {0}.')
...@@ -144,7 +151,7 @@ def const_deriv_multiplication(root, args): ...@@ -144,7 +151,7 @@ def const_deriv_multiplication(root, args):
scope.remove(c) scope.remove(c)
return c * der(scope.as_nary_node(), x) return c * same_der(root, scope.as_nary_node())
MESSAGES[const_deriv_multiplication] = \ MESSAGES[const_deriv_multiplication] = \
...@@ -188,7 +195,7 @@ def power_rule(root, args): ...@@ -188,7 +195,7 @@ def power_rule(root, args):
""" """
[f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]' [f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]'
""" """
return der(L(E) ** ln(root[0]), second_arg(root)) return same_der(root, L(E) ** ln(root[0]))
MESSAGES[power_rule] = \ MESSAGES[power_rule] = \
...@@ -326,7 +333,7 @@ def tangens(root, args): ...@@ -326,7 +333,7 @@ def tangens(root, args):
""" """
x = root[0][0] x = root[0][0]
return der(sin(x) / cos(x), second_arg(root)) return same_der(root, sin(x) / cos(x))
MESSAGES[tangens] = \ MESSAGES[tangens] = \
...@@ -365,11 +372,9 @@ def sum_rule(root, args): ...@@ -365,11 +372,9 @@ def sum_rule(root, args):
[f(x) + g(x)]' -> f'(x) + g'(x) [f(x) + g(x)]' -> f'(x) + g'(x)
""" """
scope, f = args scope, f = args
x = second_arg(root)
scope.remove(f) scope.remove(f)
return der(f, x) + der(scope.as_nary_node(), x) return same_der(root, f) + same_der(root, scope.as_nary_node())
MESSAGES[sum_rule] = _('Apply the sum rule to {0}.') MESSAGES[sum_rule] = _('Apply the sum rule to {0}.')
...@@ -383,12 +388,10 @@ def product_rule(root, args): ...@@ -383,12 +388,10 @@ def product_rule(root, args):
[f(x) * g(x) * h(x)]' -> f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]' [f(x) * g(x) * h(x)]' -> f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]'
""" """
scope, f = args scope, f = args
x = second_arg(root)
scope.remove(f) scope.remove(f)
gh = scope.as_nary_node() gh = scope.as_nary_node()
return der(f, x) * gh + f * der(gh, x) return same_der(root, f) * gh + f * same_der(root, gh)
MESSAGES[product_rule] = _('Apply the product rule to {0}.') MESSAGES[product_rule] = _('Apply the product rule to {0}.')
...@@ -419,9 +422,8 @@ def quotient_rule(root, args): ...@@ -419,9 +422,8 @@ def quotient_rule(root, args):
[f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2 [f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
""" """
f, g = root[0] f, g = root[0]
x = second_arg(root)
return (der(f, x) * g - f * der(g, x)) / g ** 2 return (same_der(root, f) * g - f * same_der(root, g)) / g ** 2
MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.') MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.')
...@@ -43,7 +43,7 @@ def match_combine_groups(node): ...@@ -43,7 +43,7 @@ def match_combine_groups(node):
if not n.is_numeric(): if not n.is_numeric():
groups.append((Leaf(1), n, n, True)) groups.append((Leaf(1), n, n, True))
# Each number multiplication yields a group, multiple occurences of # Each number multiplication yields a group, multiple occurrences of
# the same group can be replaced by a single one # the same group can be replaced by a single one
if n.is_op(OP_MUL): if n.is_op(OP_MUL):
n_scope = Scope(n) n_scope = Scope(n)
......
...@@ -85,6 +85,8 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -85,6 +85,8 @@ class TestRulesDerivatives(RulesTestCase):
self.assertEqual(const_deriv_multiplication(root, args), self.assertEqual(const_deriv_multiplication(root, args),
l2 * der(x, x)) l2 * der(x, x))
self.assertRewrite(["[2x]'", "2[x]'", '2 * 1', '2'])
def test_match_variable_power(self): def test_match_variable_power(self):
root, x, l2 = tree('d/dx x ^ 2, x, 2') root, x, l2 = tree('d/dx x ^ 2, x, 2')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment