Commit 4ffcd659 authored by Taddeus Kroes's avatar Taddeus Kroes

Added some tests for chein rule.

parent 70eeb83e
......@@ -115,11 +115,12 @@ def match_const_deriv_multiplication(node):
p = []
if node[0].is_op(OP_MUL):
x = L(get_derivation_variable(node))
scope = Scope(node[0])
for n in scope:
if n.is_numeric():
p.append(P(node, const_deriv_multiplication, (scope, n)))
if not n.contains(x):
p.append(P(node, const_deriv_multiplication, (scope, n, x)))
return p
......@@ -128,10 +129,9 @@ def const_deriv_multiplication(root, args):
"""
der(c * f(x), x) -> c * der(f(x), x)
"""
scope, c = args
scope, c, x = args
scope.remove(c)
x = L(get_derivation_variable(root))
return c * der(scope.as_nary_node(), x)
......
......@@ -49,20 +49,24 @@ class TestRulesDerivatives(RulesTestCase):
root = tree('der(2x)')
l2, x = root[0]
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), l2))])
[P(root, const_deriv_multiplication, (Scope(root[0]), l2, x))])
(x, y), x = root = tree('der(xy, x)')
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), y, x))])
def test_match_const_deriv_multiplication_multiple_constants(self):
root = tree('der(2x * 3)')
(l2, x), l3 = root[0]
scope = Scope(root[0])
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (scope, l2)),
P(root, const_deriv_multiplication, (scope, l3))])
[P(root, const_deriv_multiplication, (scope, l2, x)),
P(root, const_deriv_multiplication, (scope, l3, x))])
def test_const_deriv_multiplication(self):
root = tree('der(2x)')
l2, x = root[0]
args = Scope(root[0]), l2
args = Scope(root[0]), l2, x
self.assertEqual(const_deriv_multiplication(root, args),
l2 * der(x, x))
......@@ -94,8 +98,14 @@ class TestRulesDerivatives(RulesTestCase):
x, n = root[0]
self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
def test_variable_root_chain_rule(self):
pass
def test_variable_exponent(self):
root = tree('der(2 ^ x)')
g, x = root[0]
self.assertEqual(variable_exponent(root, ()), g ** x * ln(g))
def test_chain_rule(self):
pass
root = tree('der(2 ^ x ^ 3)')
l2, x3 = root[0]
x, l3 = x3
self.assertEqual(chain_rule(root, (x3, variable_exponent, ())),
l2 ** x3 * ln(l2) * der(x3))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment