Bladeren bron

Added a number of derivative rewrite rules.

Taddeus Kroes 14 jaren geleden
bovenliggende
commit
41f07554b9
3 gewijzigde bestanden met toevoegingen van 184 en 39 verwijderingen
  1. 6 3
      src/rules/__init__.py
  2. 108 18
      src/rules/derivatives.py
  3. 70 18
      tests/test_rules_derivatives.py

+ 6 - 3
src/rules/__init__.py

@@ -17,7 +17,9 @@ from .negation import match_negated_factor, match_negate_polynome, \
 from .sort import match_sort_multiplicants
 from .goniometry import match_add_quadrants, match_negated_parameter, \
         match_half_pi_subtraction, match_standard_radian
-from src.rules.derivatives import match_constant_derivative
+from src.rules.derivatives import match_zero_derivative, \
+        match_one_derivative, match_variable_power, \
+        match_const_deriv_multiplication
 
 RULES = {
         OP_ADD: [match_add_numerics, match_add_constant_fractions,
@@ -27,7 +29,7 @@ RULES = {
                  match_negated_factor, match_multiply_one,
                  match_sort_multiplicants, match_multiply_fractions],
         OP_DIV: [match_subtract_exponents, match_divide_numerics,
-                 match_constant_division, match_divide_fractions, \
+                 match_constant_division, match_divide_fractions,
                  match_negated_division, match_equal_fraction_parts],
         OP_POW: [match_multiply_exponents, match_duplicate_exponent,
                  match_raised_fraction, match_remove_negative_exponent,
@@ -39,5 +41,6 @@ RULES = {
         OP_COS: [match_negated_parameter, match_half_pi_subtraction,
                  match_standard_radian],
         OP_TAN: [match_standard_radian],
-        OP_DERIV: [match_constant_derivative],
+        OP_DERIV: [match_zero_derivative, match_one_derivative,
+                   match_variable_power, match_const_deriv_multiplication],
         }

+ 108 - 18
src/rules/derivatives.py

@@ -1,7 +1,9 @@
 from itertools import combinations
 
 from .utils import find_variables
-from ..node import Scope, OP_DERIV, ExpressionNode as N, ExpressionLeaf as L
+from .logarithmic import ln
+from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DERIV, \
+        OP_MUL
 from ..possibilities import Possibility as P, MESSAGES
 from ..translate import _
 
@@ -36,23 +38,48 @@ def get_derivation_variable(node, variables=None):
     return list(variables)[0]
 
 
-def match_constant_derivative(node):
+def chain_rule(root, args):
     """
-    der(x)     ->  1
-    der(x, x)  ->  1
-    der(x, y)  ->  x
+    Apply the chain rule:
+    [f(g(x)]'  ->  f'(g(x)) * g'(x)
+
+    f'(g(x)) is not expressable in the current syntax, so calculate it directly
+    using the application function in the arguments. g'(x) is simply expressed
+    as der(g(x), x).
+    """
+    g, f_deriv, f_deriv_args = args
+    x = root[1] if len(root) > 1 else None
+
+    return f_deriv(root, f_deriv_args) * der(g, x)
+
+
+def match_zero_derivative(node):
+    """
+    der(x, y)  ->  0
     der(n)     ->  0
     """
     assert node.is_op(OP_DERIV)
 
     variables = find_variables(node[0])
-    var = get_derivation_variable(node, variables=variables)
+    var = get_derivation_variable(node, variables)
 
     if not var or var not in variables:
-        return [P(node, zero_derivative, ())]
+        return [P(node, zero_derivative)]
+
+    return []
+
 
-    if (node[0] == node[1] if len(node) > 1 else node[0].is_variable()):
-        return [P(node, one_derivative, ())]
+def match_one_derivative(node):
+    """
+    der(x)     ->  1  # Implicit x
+    der(x, x)  ->  1  # Explicit x
+    """
+    assert node.is_op(OP_DERIV)
+
+    var = get_derivation_variable(node)
+
+    if var and node[0] == L(var):
+        return [P(node, one_derivative)]
 
     return []
 
@@ -70,7 +97,8 @@ MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
 
 def zero_derivative(root, args):
     """
-    der(n)  ->  0
+    der(x, y)  ->  0
+    der(n)     ->  0
     """
     return L(0)
 
@@ -78,27 +106,89 @@ def zero_derivative(root, args):
 MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
 
 
+def match_const_deriv_multiplication(node):
+    """
+    [f(c * x)]'  ->  c * [f(x)]'
+    """
+    assert node.is_op(OP_DERIV)
+
+    p = []
+
+    if node[0].is_op(OP_MUL):
+        scope = Scope(node[0])
+
+        for n in scope:
+            if n.is_numeric():
+                p.append(P(node, const_deriv_multiplication, (scope, n)))
+
+    return p
+
+
+def const_deriv_multiplication(root, args):
+    """
+    [f(c * x)]'  ->  c * [f(x)]'
+    """
+    scope, c = args
+
+    scope.remove(c)
+    x = L(get_derivation_variable(root))
+
+    # FIXME: is the explicit 'x' parameter necessary?
+    return c * der(scope.as_nary_node(), x)
+
+
+MESSAGES[const_deriv_multiplication] = \
+        _('Bring multiplication with {2} in derivative {0} to the outside.')
+
+
 def match_variable_power(node):
     """
     der(x ^ n)     ->  n * x ^ (n - 1)
     der(x ^ n, x)  ->  n * x ^ (n - 1)
-    der(x ^ f(x))  ->  n * x ^ (n - 1)
+    der(f(x) ^ n)  ->  n * f(x) ^ (n - 1) * der(f(x))  # Chain rule
     """
     assert node.is_op(OP_DERIV)
 
-    if node[0].is_power():
-        x, n = node[0]
+    if not node[0].is_power():
+        return []
+
+    root, exponent = node[0]
+
+    rvars = find_variables(root)
+    evars = find_variables(exponent)
+    x = get_derivation_variable(node, rvars | evars)
 
-        if x.is_variable():
-            return [P(node, variable_power, ())]
+    if x in rvars and x not in evars:
+        if root.is_variable():
+            return [P(node, variable_root)]
+
+        return [P(node, chain_rule, (root, variable_root, ()))]
+    elif not x in rvars and x in evars:
+        if exponent.is_variable():
+            return [P(node, variable_exponent)]
+
+        return [P(node, chain_rule, (root, variable_exponent, ()))]
 
     return []
 
 
-def variable_power(root, args):
+def variable_root(root, args):
     """
     der(x ^ n, x)  ->  n * x ^ (n - 1)
     """
-    x, n = args
+    x, n = root[0]
+
+    return n * x ** (n - 1)
+
+
+def variable_exponent(root, args):
+    """
+    der(g ^ x, x)  ->  g ^ x * ln(g)
+
+    Note that (in combination with logarithmic/constant rules):
+    der(e ^ x)  ->  e ^ x * ln(e)  ->  e ^ x * 1  ->  e ^ x
+    """
+    # TODO: Put above example 'der(e ^ x)' in unit test
+    g, x = root[0]
 
-    return n * x ^ (n - 1)
+    return g ** x * ln(g)

+ 70 - 18
tests/test_rules_derivatives.py

@@ -1,5 +1,9 @@
-from src.rules.derivatives import get_derivation_variable, \
-        match_constant_derivative, one_derivative, zero_derivative
+from src.rules.derivatives import der, get_derivation_variable, \
+        match_zero_derivative, match_one_derivative, one_derivative, \
+        zero_derivative, match_variable_power, variable_root, \
+        match_const_deriv_multiplication, const_deriv_multiplication, \
+        chain_rule
+from src.node import Scope
 from src.possibilities import Possibility as P
 from tests.rulestestcase import RulesTestCase, tree
 
@@ -14,27 +18,75 @@ class TestRulesDerivatives(RulesTestCase):
 
         self.assertRaises(ValueError, tree, 'der(xy)')
 
-    def test_match_constant_derivative(self):
-        root = tree('der(x)')
-        self.assertEqualPos(match_constant_derivative(root),
-                [P(root, one_derivative, ())])
-
-        root = tree('der(x, x)')
-        self.assertEqualPos(match_constant_derivative(root),
-                [P(root, one_derivative, ())])
-
+    def test_match_zero_derivative(self):
         root = tree('der(x, y)')
-        self.assertEqualPos(match_constant_derivative(root),
-                [P(root, zero_derivative, ())])
+        self.assertEqualPos(match_zero_derivative(root),
+                [P(root, zero_derivative)])
 
         root = tree('der(2)')
-        self.assertEqualPos(match_constant_derivative(root),
-                [P(root, zero_derivative, ())])
+        self.assertEqualPos(match_zero_derivative(root),
+                [P(root, zero_derivative)])
+
+    def test_zero_derivative(self):
+        root = tree('der(1)')
+        self.assertEqual(zero_derivative(root, ()), 0)
+
+    def test_match_one_derivative(self):
+        root = tree('der(x)')
+        self.assertEqualPos(match_one_derivative(root),
+                [P(root, one_derivative)])
+
+        root = tree('der(x, x)')
+        self.assertEqualPos(match_one_derivative(root),
+                [P(root, one_derivative)])
 
     def test_one_derivative(self):
         root = tree('der(x)')
         self.assertEqual(one_derivative(root, ()), 1)
 
-    def test_zero_derivative(self):
-        root = tree('der(1)')
-        self.assertEqual(zero_derivative(root, ()), 0)
+    def test_match_const_deriv_multiplication(self):
+        root = tree('der(2x)')
+        l2, x = root[0]
+        self.assertEqualPos(match_const_deriv_multiplication(root),
+                [P(root, const_deriv_multiplication, (Scope(root[0]), l2))])
+
+    def test_match_const_deriv_multiplication_multiple_constants(self):
+        root = tree('der(2x * 3)')
+        (l2, x), l3 = root[0]
+        scope = Scope(root[0])
+        self.assertEqualPos(match_const_deriv_multiplication(root),
+                [P(root, const_deriv_multiplication, (scope, l2)),
+                 P(root, const_deriv_multiplication, (scope, l3))])
+
+    def test_const_deriv_multiplication(self):
+        root = tree('der(2x)')
+        l2, x = root[0]
+        args = Scope(root[0]), l2
+        self.assertEqual(const_deriv_multiplication(root, args),
+                         l2 * der(x, x))
+
+    def test_match_variable_power(self):
+        root, x, l2 = tree('der(x ^ 2), x, 2')
+        self.assertEqualPos(match_variable_power(root),
+                [P(root, variable_root)])
+
+    def test_match_variable_power_chain_rule(self):
+        root, x, l2, x3 = tree('der((x ^ 3) ^ 2), x, 2, x ^ 3')
+        self.assertEqualPos(match_variable_power(root),
+                [P(root, chain_rule, (x3, variable_root, ()))])
+
+        # Below is not mathematically underivable, it's just not within the
+        # scope of our program
+        root, x = tree('der(x ^ x), x')
+        self.assertEqualPos(match_variable_power(root), [])
+
+    def test_variable_root(self):
+        root = tree('der(x ^ 2)')
+        x, n = root[0]
+        self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
+
+    def test_variable_root_chain_rule(self):
+        pass
+
+    def test_chain_rule(self):
+        pass