Commit 41f07554 authored by Taddeus Kroes's avatar Taddeus Kroes

Added a number of derivative rewrite rules.

parent a156a2b5
......@@ -17,7 +17,9 @@ from .negation import match_negated_factor, match_negate_polynome, \
from .sort import match_sort_multiplicants
from .goniometry import match_add_quadrants, match_negated_parameter, \
match_half_pi_subtraction, match_standard_radian
from src.rules.derivatives import match_constant_derivative
from src.rules.derivatives import match_zero_derivative, \
match_one_derivative, match_variable_power, \
match_const_deriv_multiplication
RULES = {
OP_ADD: [match_add_numerics, match_add_constant_fractions,
......@@ -27,7 +29,7 @@ RULES = {
match_negated_factor, match_multiply_one,
match_sort_multiplicants, match_multiply_fractions],
OP_DIV: [match_subtract_exponents, match_divide_numerics,
match_constant_division, match_divide_fractions, \
match_constant_division, match_divide_fractions,
match_negated_division, match_equal_fraction_parts],
OP_POW: [match_multiply_exponents, match_duplicate_exponent,
match_raised_fraction, match_remove_negative_exponent,
......@@ -39,5 +41,6 @@ RULES = {
OP_COS: [match_negated_parameter, match_half_pi_subtraction,
match_standard_radian],
OP_TAN: [match_standard_radian],
OP_DERIV: [match_constant_derivative],
OP_DERIV: [match_zero_derivative, match_one_derivative,
match_variable_power, match_const_deriv_multiplication],
}
from itertools import combinations
from .utils import find_variables
from ..node import Scope, OP_DERIV, ExpressionNode as N, ExpressionLeaf as L
from .logarithmic import ln
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DERIV, \
OP_MUL
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
......@@ -36,23 +38,48 @@ def get_derivation_variable(node, variables=None):
return list(variables)[0]
def match_constant_derivative(node):
def chain_rule(root, args):
"""
der(x) -> 1
der(x, x) -> 1
der(x, y) -> x
Apply the chain rule:
[f(g(x)]' -> f'(g(x)) * g'(x)
f'(g(x)) is not expressable in the current syntax, so calculate it directly
using the application function in the arguments. g'(x) is simply expressed
as der(g(x), x).
"""
g, f_deriv, f_deriv_args = args
x = root[1] if len(root) > 1 else None
return f_deriv(root, f_deriv_args) * der(g, x)
def match_zero_derivative(node):
"""
der(x, y) -> 0
der(n) -> 0
"""
assert node.is_op(OP_DERIV)
variables = find_variables(node[0])
var = get_derivation_variable(node, variables=variables)
var = get_derivation_variable(node, variables)
if not var or var not in variables:
return [P(node, zero_derivative, ())]
return [P(node, zero_derivative)]
return []
if (node[0] == node[1] if len(node) > 1 else node[0].is_variable()):
return [P(node, one_derivative, ())]
def match_one_derivative(node):
"""
der(x) -> 1 # Implicit x
der(x, x) -> 1 # Explicit x
"""
assert node.is_op(OP_DERIV)
var = get_derivation_variable(node)
if var and node[0] == L(var):
return [P(node, one_derivative)]
return []
......@@ -70,6 +97,7 @@ MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
def zero_derivative(root, args):
"""
der(x, y) -> 0
der(n) -> 0
"""
return L(0)
......@@ -78,27 +106,89 @@ def zero_derivative(root, args):
MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
def match_const_deriv_multiplication(node):
"""
[f(c * x)]' -> c * [f(x)]'
"""
assert node.is_op(OP_DERIV)
p = []
if node[0].is_op(OP_MUL):
scope = Scope(node[0])
for n in scope:
if n.is_numeric():
p.append(P(node, const_deriv_multiplication, (scope, n)))
return p
def const_deriv_multiplication(root, args):
"""
[f(c * x)]' -> c * [f(x)]'
"""
scope, c = args
scope.remove(c)
x = L(get_derivation_variable(root))
# FIXME: is the explicit 'x' parameter necessary?
return c * der(scope.as_nary_node(), x)
MESSAGES[const_deriv_multiplication] = \
_('Bring multiplication with {2} in derivative {0} to the outside.')
def match_variable_power(node):
"""
der(x ^ n) -> n * x ^ (n - 1)
der(x ^ n, x) -> n * x ^ (n - 1)
der(x ^ f(x)) -> n * x ^ (n - 1)
der(f(x) ^ n) -> n * f(x) ^ (n - 1) * der(f(x)) # Chain rule
"""
assert node.is_op(OP_DERIV)
if node[0].is_power():
x, n = node[0]
if not node[0].is_power():
return []
root, exponent = node[0]
rvars = find_variables(root)
evars = find_variables(exponent)
x = get_derivation_variable(node, rvars | evars)
if x.is_variable():
return [P(node, variable_power, ())]
if x in rvars and x not in evars:
if root.is_variable():
return [P(node, variable_root)]
return [P(node, chain_rule, (root, variable_root, ()))]
elif not x in rvars and x in evars:
if exponent.is_variable():
return [P(node, variable_exponent)]
return [P(node, chain_rule, (root, variable_exponent, ()))]
return []
def variable_power(root, args):
def variable_root(root, args):
"""
der(x ^ n, x) -> n * x ^ (n - 1)
"""
x, n = args
x, n = root[0]
return n * x ** (n - 1)
def variable_exponent(root, args):
"""
der(g ^ x, x) -> g ^ x * ln(g)
Note that (in combination with logarithmic/constant rules):
der(e ^ x) -> e ^ x * ln(e) -> e ^ x * 1 -> e ^ x
"""
# TODO: Put above example 'der(e ^ x)' in unit test
g, x = root[0]
return n * x ^ (n - 1)
return g ** x * ln(g)
from src.rules.derivatives import get_derivation_variable, \
match_constant_derivative, one_derivative, zero_derivative
from src.rules.derivatives import der, get_derivation_variable, \
match_zero_derivative, match_one_derivative, one_derivative, \
zero_derivative, match_variable_power, variable_root, \
match_const_deriv_multiplication, const_deriv_multiplication, \
chain_rule
from src.node import Scope
from src.possibilities import Possibility as P
from tests.rulestestcase import RulesTestCase, tree
......@@ -14,27 +18,75 @@ class TestRulesDerivatives(RulesTestCase):
self.assertRaises(ValueError, tree, 'der(xy)')
def test_match_constant_derivative(self):
root = tree('der(x)')
self.assertEqualPos(match_constant_derivative(root),
[P(root, one_derivative, ())])
root = tree('der(x, x)')
self.assertEqualPos(match_constant_derivative(root),
[P(root, one_derivative, ())])
def test_match_zero_derivative(self):
root = tree('der(x, y)')
self.assertEqualPos(match_constant_derivative(root),
[P(root, zero_derivative, ())])
self.assertEqualPos(match_zero_derivative(root),
[P(root, zero_derivative)])
root = tree('der(2)')
self.assertEqualPos(match_constant_derivative(root),
[P(root, zero_derivative, ())])
self.assertEqualPos(match_zero_derivative(root),
[P(root, zero_derivative)])
def test_zero_derivative(self):
root = tree('der(1)')
self.assertEqual(zero_derivative(root, ()), 0)
def test_match_one_derivative(self):
root = tree('der(x)')
self.assertEqualPos(match_one_derivative(root),
[P(root, one_derivative)])
root = tree('der(x, x)')
self.assertEqualPos(match_one_derivative(root),
[P(root, one_derivative)])
def test_one_derivative(self):
root = tree('der(x)')
self.assertEqual(one_derivative(root, ()), 1)
def test_zero_derivative(self):
root = tree('der(1)')
self.assertEqual(zero_derivative(root, ()), 0)
def test_match_const_deriv_multiplication(self):
root = tree('der(2x)')
l2, x = root[0]
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), l2))])
def test_match_const_deriv_multiplication_multiple_constants(self):
root = tree('der(2x * 3)')
(l2, x), l3 = root[0]
scope = Scope(root[0])
self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (scope, l2)),
P(root, const_deriv_multiplication, (scope, l3))])
def test_const_deriv_multiplication(self):
root = tree('der(2x)')
l2, x = root[0]
args = Scope(root[0]), l2
self.assertEqual(const_deriv_multiplication(root, args),
l2 * der(x, x))
def test_match_variable_power(self):
root, x, l2 = tree('der(x ^ 2), x, 2')
self.assertEqualPos(match_variable_power(root),
[P(root, variable_root)])
def test_match_variable_power_chain_rule(self):
root, x, l2, x3 = tree('der((x ^ 3) ^ 2), x, 2, x ^ 3')
self.assertEqualPos(match_variable_power(root),
[P(root, chain_rule, (x3, variable_root, ()))])
# Below is not mathematically underivable, it's just not within the
# scope of our program
root, x = tree('der(x ^ x), x')
self.assertEqualPos(match_variable_power(root), [])
def test_variable_root(self):
root = tree('der(x ^ 2)')
x, n = root[0]
self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
def test_variable_root_chain_rule(self):
pass
def test_chain_rule(self):
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment