Parcourir la source

Added some derivative rules.

Taddeus Kroes il y a 14 ans
Parent
commit
307006701b
3 fichiers modifiés avec 157 ajouts et 0 suppressions
  1. 104 0
      src/rules/derivatives.py
  2. 13 0
      src/rules/utils.py
  3. 40 0
      tests/test_rules_derivatives.py

+ 104 - 0
src/rules/derivatives.py

@@ -0,0 +1,104 @@
+from itertools import combinations
+
+from .utils import find_variables
+from ..node import Scope, OP_DERIV, ExpressionNode as N, ExpressionLeaf as L
+from ..possibilities import Possibility as P, MESSAGES
+from ..translate import _
+
+
+def der(f, x=None):
+    return N('der', f, x) if x else N('der', f)
+
+
+def get_derivation_variable(node, variables=None):
+    """
+    Find the variable to derive over.
+
+    >>> print get_derivation_variable(der(L('x')))
+    'x'
+    """
+    if len(node) > 1:
+        assert node[1].is_identifier()
+        return node[1].value
+
+    if not variables:
+        variables = find_variables(node)
+
+    if len(variables) > 1:
+        # FIXME: Use first variable, sorted alphabetically?
+        #return sorted(variables)[0]
+        raise ValueError('More than 1 variable in implicit derivative: '
+                         + ', '.join(variables))
+
+    if not len(variables):
+        return None
+
+    return list(variables)[0]
+
+
+def match_constant_derivative(node):
+    """
+    der(x)     ->  1
+    der(x, x)  ->  1
+    der(x, y)  ->  x
+    der(n)     ->  0
+    """
+    assert node.is_op(OP_DERIV)
+
+    variables = find_variables(node[0])
+    var = get_derivation_variable(node, variables=variables)
+
+    if not var or var not in variables:
+        return [P(node, zero_derivative, ())]
+
+    if (node[0] == node[1] if len(node) > 1 else node[0].is_variable()):
+        return [P(node, one_derivative, ())]
+
+    return []
+
+
+def one_derivative(root, args):
+    """
+    der(x)     ->  1
+    der(x, x)  ->  1
+    """
+    return L(1)
+
+
+MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
+
+
+def zero_derivative(root, args):
+    """
+    der(n)  ->  0
+    """
+    return L(0)
+
+
+MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
+
+
+def match_variable_power(node):
+    """
+    der(x ^ n)     ->  n * x ^ (n - 1)
+    der(x ^ n, x)  ->  n * x ^ (n - 1)
+    der(x ^ f(x))  ->  n * x ^ (n - 1)
+    """
+    assert node.is_op(OP_DERIV)
+
+    if node[0].is_power():
+        x, n = node[0]
+
+        if x.is_variable():
+            return [P(node, variable_power, ())]
+
+    return []
+
+
+def variable_power(root, args):
+    """
+    der(x ^ n, x)  ->  n * x ^ (n - 1)
+    """
+    x, n = args
+
+    return n * x ^ (n - 1)

+ 13 - 0
src/rules/utils.py

@@ -71,3 +71,16 @@ def partition(callback, iterable):
         (a if callback(item) else b).append(item)
 
     return a, b
+
+
+def find_variables(node):
+    """
+    Find all variables in a node.
+    """
+    if node.is_variable():
+        return set([node.value])
+
+    if not node.is_leaf:
+        return reduce(lambda a, b: a | b, map(find_variables, node))
+
+    return set()

+ 40 - 0
tests/test_rules_derivatives.py

@@ -0,0 +1,40 @@
+from src.rules.derivatives import get_derivation_variable, \
+        match_constant_derivative, one_derivative, zero_derivative
+from src.possibilities import Possibility as P
+from tests.rulestestcase import RulesTestCase, tree
+
+
+class TestRulesDerivatives(RulesTestCase):
+
+    def test_get_derivation_variable(self):
+        xy, x, l1 = tree('der(xy, x), der(x), der(1)')
+        self.assertEqual(get_derivation_variable(xy), 'x')
+        self.assertEqual(get_derivation_variable(x), 'x')
+        self.assertIsNone(get_derivation_variable(l1))
+
+        self.assertRaises(ValueError, tree, 'der(xy)')
+
+    def test_match_constant_derivative(self):
+        root = tree('der(x)')
+        self.assertEqualPos(match_constant_derivative(root),
+                [P(root, one_derivative, ())])
+
+        root = tree('der(x, x)')
+        self.assertEqualPos(match_constant_derivative(root),
+                [P(root, one_derivative, ())])
+
+        root = tree('der(x, y)')
+        self.assertEqualPos(match_constant_derivative(root),
+                [P(root, zero_derivative, ())])
+
+        root = tree('der(2)')
+        self.assertEqualPos(match_constant_derivative(root),
+                [P(root, zero_derivative, ())])
+
+    def test_one_derivative(self):
+        root = tree('der(x)')
+        self.assertEqual(one_derivative(root, ()), 1)
+
+    def test_zero_derivative(self):
+        root = tree('der(1)')
+        self.assertEqual(zero_derivative(root, ()), 0)