Commit 4ffcd659 authored by Taddeus Kroes's avatar Taddeus Kroes

Added some tests for chein rule.

parent 70eeb83e
...@@ -115,11 +115,12 @@ def match_const_deriv_multiplication(node): ...@@ -115,11 +115,12 @@ def match_const_deriv_multiplication(node):
p = [] p = []
if node[0].is_op(OP_MUL): if node[0].is_op(OP_MUL):
x = L(get_derivation_variable(node))
scope = Scope(node[0]) scope = Scope(node[0])
for n in scope: for n in scope:
if n.is_numeric(): if not n.contains(x):
p.append(P(node, const_deriv_multiplication, (scope, n))) p.append(P(node, const_deriv_multiplication, (scope, n, x)))
return p return p
...@@ -128,10 +129,9 @@ def const_deriv_multiplication(root, args): ...@@ -128,10 +129,9 @@ def const_deriv_multiplication(root, args):
""" """
der(c * f(x), x) -> c * der(f(x), x) der(c * f(x), x) -> c * der(f(x), x)
""" """
scope, c = args scope, c, x = args
scope.remove(c) scope.remove(c)
x = L(get_derivation_variable(root))
return c * der(scope.as_nary_node(), x) return c * der(scope.as_nary_node(), x)
......
...@@ -49,20 +49,24 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -49,20 +49,24 @@ class TestRulesDerivatives(RulesTestCase):
root = tree('der(2x)') root = tree('der(2x)')
l2, x = root[0] l2, x = root[0]
self.assertEqualPos(match_const_deriv_multiplication(root), self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), l2))]) [P(root, const_deriv_multiplication, (Scope(root[0]), l2, x))])
(x, y), x = root = tree('der(xy, x)')
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), y, x))])
def test_match_const_deriv_multiplication_multiple_constants(self): def test_match_const_deriv_multiplication_multiple_constants(self):
root = tree('der(2x * 3)') root = tree('der(2x * 3)')
(l2, x), l3 = root[0] (l2, x), l3 = root[0]
scope = Scope(root[0]) scope = Scope(root[0])
self.assertEqualPos(match_const_deriv_multiplication(root), self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (scope, l2)), [P(root, const_deriv_multiplication, (scope, l2, x)),
P(root, const_deriv_multiplication, (scope, l3))]) P(root, const_deriv_multiplication, (scope, l3, x))])
def test_const_deriv_multiplication(self): def test_const_deriv_multiplication(self):
root = tree('der(2x)') root = tree('der(2x)')
l2, x = root[0] l2, x = root[0]
args = Scope(root[0]), l2 args = Scope(root[0]), l2, x
self.assertEqual(const_deriv_multiplication(root, args), self.assertEqual(const_deriv_multiplication(root, args),
l2 * der(x, x)) l2 * der(x, x))
...@@ -94,8 +98,14 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -94,8 +98,14 @@ class TestRulesDerivatives(RulesTestCase):
x, n = root[0] x, n = root[0]
self.assertEqual(variable_root(root, ()), n * x ** (n - 1)) self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
def test_variable_root_chain_rule(self): def test_variable_exponent(self):
pass root = tree('der(2 ^ x)')
g, x = root[0]
self.assertEqual(variable_exponent(root, ()), g ** x * ln(g))
def test_chain_rule(self): def test_chain_rule(self):
pass root = tree('der(2 ^ x ^ 3)')
l2, x3 = root[0]
x, l3 = x3
self.assertEqual(chain_rule(root, (x3, variable_exponent, ())),
l2 ** x3 * ln(l2) * der(x3))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment