derivatives.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. from itertools import combinations
  2. from .utils import find_variables
  3. from .logarithmic import ln
  4. from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DERIV, \
  5. OP_MUL
  6. from ..possibilities import Possibility as P, MESSAGES
  7. from ..translate import _
  8. def der(f, x=None):
  9. return N('der', f, x) if x else N('der', f)
  10. def get_derivation_variable(node, variables=None):
  11. """
  12. Find the variable to derive over.
  13. >>> print get_derivation_variable(der(L('x')))
  14. 'x'
  15. """
  16. if len(node) > 1:
  17. assert node[1].is_identifier()
  18. return node[1].value
  19. if not variables:
  20. variables = find_variables(node)
  21. if len(variables) > 1:
  22. # FIXME: Use first variable, sorted alphabetically?
  23. #return sorted(variables)[0]
  24. raise ValueError('More than 1 variable in implicit derivative: '
  25. + ', '.join(variables))
  26. if not len(variables):
  27. return None
  28. return list(variables)[0]
  29. def chain_rule(root, args):
  30. """
  31. Apply the chain rule:
  32. [f(g(x)]' -> f'(g(x)) * g'(x)
  33. f'(g(x)) is not expressable in the current syntax, so calculate it directly
  34. using the application function in the arguments. g'(x) is simply expressed
  35. as der(g(x), x).
  36. """
  37. g, f_deriv, f_deriv_args = args
  38. x = root[1] if len(root) > 1 else None
  39. return f_deriv(root, f_deriv_args) * der(g, x)
  40. def match_zero_derivative(node):
  41. """
  42. der(x, y) -> 0
  43. der(n) -> 0
  44. """
  45. assert node.is_op(OP_DERIV)
  46. variables = find_variables(node[0])
  47. var = get_derivation_variable(node, variables)
  48. if not var or var not in variables:
  49. return [P(node, zero_derivative)]
  50. return []
  51. def match_one_derivative(node):
  52. """
  53. der(x) -> 1 # Implicit x
  54. der(x, x) -> 1 # Explicit x
  55. """
  56. assert node.is_op(OP_DERIV)
  57. var = get_derivation_variable(node)
  58. if var and node[0] == L(var):
  59. return [P(node, one_derivative)]
  60. return []
  61. def one_derivative(root, args):
  62. """
  63. der(x) -> 1
  64. der(x, x) -> 1
  65. """
  66. return L(1)
  67. MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
  68. def zero_derivative(root, args):
  69. """
  70. der(x, y) -> 0
  71. der(n) -> 0
  72. """
  73. return L(0)
  74. MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
  75. def match_const_deriv_multiplication(node):
  76. """
  77. der(c * f(x), x) -> c * der(f(x), x)
  78. """
  79. assert node.is_op(OP_DERIV)
  80. p = []
  81. if node[0].is_op(OP_MUL):
  82. scope = Scope(node[0])
  83. for n in scope:
  84. if n.is_numeric():
  85. p.append(P(node, const_deriv_multiplication, (scope, n)))
  86. return p
  87. def const_deriv_multiplication(root, args):
  88. """
  89. der(c * f(x), x) -> c * der(f(x), x)
  90. """
  91. scope, c = args
  92. scope.remove(c)
  93. x = L(get_derivation_variable(root))
  94. return c * der(scope.as_nary_node(), x)
  95. MESSAGES[const_deriv_multiplication] = \
  96. _('Bring multiplication with {2} in derivative {0} to the outside.')
  97. def match_variable_power(node):
  98. """
  99. der(x ^ n) -> n * x ^ (n - 1)
  100. der(x ^ n, x) -> n * x ^ (n - 1)
  101. der(f(x) ^ n) -> n * f(x) ^ (n - 1) * der(f(x)) # Chain rule
  102. """
  103. assert node.is_op(OP_DERIV)
  104. if not node[0].is_power():
  105. return []
  106. root, exponent = node[0]
  107. rvars = find_variables(root)
  108. evars = find_variables(exponent)
  109. x = get_derivation_variable(node, rvars | evars)
  110. if x in rvars and x not in evars:
  111. if root.is_variable():
  112. return [P(node, variable_root)]
  113. return [P(node, chain_rule, (root, variable_root, ()))]
  114. elif not x in rvars and x in evars:
  115. if exponent.is_variable():
  116. return [P(node, variable_exponent)]
  117. return [P(node, chain_rule, (exponent, variable_exponent, ()))]
  118. return []
  119. def variable_root(root, args):
  120. """
  121. der(x ^ n, x) -> n * x ^ (n - 1)
  122. """
  123. x, n = root[0]
  124. return n * x ** (n - 1)
  125. MESSAGES[variable_root] = \
  126. _('Apply standard derivative d/dx x ^ n = n * x ^ (n - 1) on {0}.')
  127. def variable_exponent(root, args):
  128. """
  129. der(g ^ x, x) -> g ^ x * ln(g)
  130. Note that (in combination with logarithmic/constant rules):
  131. der(e ^ x) -> e ^ x * ln(e) -> e ^ x * 1 -> e ^ x
  132. """
  133. # TODO: Put above example 'der(e ^ x)' in unit test
  134. g, x = root[0]
  135. return g ** x * ln(g)
  136. MESSAGES[variable_exponent] = \
  137. _('Apply standard derivative d/dx g ^ x = g ^ x * ln g.')