derivatives.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from itertools import combinations
  2. from .utils import find_variables
  3. from .logarithmic import ln
  4. from .goniometry import sin, cos
  5. from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \
  6. OP_MUL, OP_LOG, OP_SIN, OP_COS, OP_TAN
  7. from ..possibilities import Possibility as P, MESSAGES
  8. from ..translate import _
  9. def der(f, x=None):
  10. return N('der', f, x) if x else N('der', f)
  11. def get_derivation_variable(node, variables=None):
  12. """
  13. Find the variable to derive over.
  14. >>> print get_derivation_variable(der(L('x')))
  15. 'x'
  16. """
  17. if len(node) > 1:
  18. assert node[1].is_identifier()
  19. return node[1].value
  20. if not variables:
  21. variables = find_variables(node)
  22. if len(variables) > 1:
  23. # FIXME: Use first variable, sorted alphabetically?
  24. #return sorted(variables)[0]
  25. raise ValueError('More than 1 variable in implicit derivative: '
  26. + ', '.join(variables))
  27. if not len(variables):
  28. return None
  29. return list(variables)[0]
  30. def chain_rule(root, args):
  31. """
  32. Apply the chain rule:
  33. [f(g(x)]' -> f'(g(x)) * g'(x)
  34. f'(g(x)) is not expressable in the current syntax, so calculate it directly
  35. using the application function in the arguments. g'(x) is simply expressed
  36. as der(g(x), x).
  37. """
  38. g, f_deriv, f_deriv_args = args
  39. x = root[1] if len(root) > 1 else None
  40. return f_deriv(root, f_deriv_args) * der(g, x)
  41. def match_zero_derivative(node):
  42. """
  43. der(x, y) -> 0
  44. der(n) -> 0
  45. """
  46. assert node.is_op(OP_DER)
  47. variables = find_variables(node[0])
  48. var = get_derivation_variable(node, variables)
  49. if not var or var not in variables:
  50. return [P(node, zero_derivative)]
  51. return []
  52. def match_one_derivative(node):
  53. """
  54. der(x) -> 1 # Implicit x
  55. der(x, x) -> 1 # Explicit x
  56. """
  57. assert node.is_op(OP_DER)
  58. var = get_derivation_variable(node)
  59. if var and node[0] == L(var):
  60. return [P(node, one_derivative)]
  61. return []
  62. def one_derivative(root, args):
  63. """
  64. der(x) -> 1
  65. der(x, x) -> 1
  66. """
  67. return L(1)
  68. MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
  69. def zero_derivative(root, args):
  70. """
  71. der(x, y) -> 0
  72. der(n) -> 0
  73. """
  74. return L(0)
  75. MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
  76. def match_const_deriv_multiplication(node):
  77. """
  78. der(c * f(x), x) -> c * der(f(x), x)
  79. """
  80. assert node.is_op(OP_DER)
  81. p = []
  82. if node[0].is_op(OP_MUL):
  83. x = L(get_derivation_variable(node))
  84. scope = Scope(node[0])
  85. for n in scope:
  86. if not n.contains(x):
  87. p.append(P(node, const_deriv_multiplication, (scope, n, x)))
  88. return p
  89. def const_deriv_multiplication(root, args):
  90. """
  91. der(c * f(x), x) -> c * der(f(x), x)
  92. """
  93. scope, c, x = args
  94. scope.remove(c)
  95. return c * der(scope.as_nary_node(), x)
  96. MESSAGES[const_deriv_multiplication] = \
  97. _('Bring multiplication with {2} in derivative {0} to the outside.')
  98. def match_variable_power(node):
  99. """
  100. der(x ^ n) -> n * x ^ (n - 1)
  101. der(x ^ n, x) -> n * x ^ (n - 1)
  102. der(f(x) ^ n) -> n * f(x) ^ (n - 1) * der(f(x)) # Chain rule
  103. """
  104. assert node.is_op(OP_DER)
  105. if not node[0].is_power():
  106. return []
  107. root, exponent = node[0]
  108. rvars = find_variables(root)
  109. evars = find_variables(exponent)
  110. x = get_derivation_variable(node, rvars | evars)
  111. if x in rvars and x not in evars:
  112. if root.is_variable():
  113. return [P(node, variable_root)]
  114. return [P(node, chain_rule, (root, variable_root, ()))]
  115. elif not x in rvars and x in evars:
  116. if exponent.is_variable():
  117. return [P(node, variable_exponent)]
  118. return [P(node, chain_rule, (exponent, variable_exponent, ()))]
  119. return []
  120. def variable_root(root, args):
  121. """
  122. der(x ^ n, x) -> n * x ^ (n - 1)
  123. """
  124. x, n = root[0]
  125. return n * x ** (n - 1)
  126. MESSAGES[variable_root] = \
  127. _('Apply standard derivative d/dx x ^ n = n * x ^ (n - 1) on {0}.')
  128. def variable_exponent(root, args):
  129. """
  130. der(g ^ x, x) -> g ^ x * ln(g)
  131. Note that (in combination with logarithmic/constant rules):
  132. der(e ^ x) -> e ^ x * ln(e) -> e ^ x * 1 -> e ^ x
  133. """
  134. # TODO: Put above example 'der(e ^ x)' in unit test
  135. g, x = root[0]
  136. return g ** x * ln(g)
  137. MESSAGES[variable_exponent] = \
  138. _('Apply standard derivative d/dx g ^ x = g ^ x * ln g.')
  139. def match_logarithmic(node):
  140. """
  141. der(log(x, g), x) -> 1 / (x * ln(g))
  142. der(log(f(x), g), x) -> 1 / (f(x) * ln(g)) * der(f(x), x)
  143. """
  144. assert node.is_op(OP_DER)
  145. x = get_derivation_variable(node)
  146. if x and node[0].is_op(OP_LOG):
  147. f = node[0][0]
  148. x = L(x)
  149. if f == x:
  150. return [P(node, logarithmic, ())]
  151. if f.contains(x):
  152. return [P(node, chain_rule, (f, logarithmic, ()))]
  153. return []
  154. def logarithmic(root, args):
  155. """
  156. der(log(x, g), x) -> 1 / (x * ln(g))
  157. """
  158. x, g = root[0]
  159. return L(1) / (x * ln(g))
  160. MESSAGES[logarithmic] = \
  161. _('Apply standard derivative d/dx log(x, g) = 1 / (x * ln(g)).')
  162. def match_goniometric(node):
  163. """
  164. der(sin(x), x) -> cos(x)
  165. der(sin(f(x)), x) -> cos(f(x)) * der(f(x), x)
  166. der(cos(x), x) -> -sin(x)
  167. der(cos(f(x)), x) -> -sin(f(x)) * der(f(x), x)
  168. der(tan(x), x) -> der(sin(x) / cos(x), x)
  169. """
  170. assert node.is_op(OP_DER)
  171. x = get_derivation_variable(node)
  172. if x and not node[0].is_leaf:
  173. op = node[0].op
  174. if op in (OP_SIN, OP_COS):
  175. f = node[0][0]
  176. x = L(x)
  177. handler = sinus if op == OP_SIN else cosinus
  178. if f == x:
  179. return [P(node, handler)]
  180. if f.contains(x):
  181. return [P(node, chain_rule, (f, handler, ()))]
  182. if op == OP_TAN:
  183. return [P(node, tangens)]
  184. return []
  185. def sinus(root, args):
  186. """
  187. der(sin(x), x) -> cos(x)
  188. """
  189. return cos(root[0][0])
  190. MESSAGES[sinus] = _('Apply standard derivative d/dx sin(x) = cos(x).')
  191. def cosinus(root, args):
  192. """
  193. der(cos(x), x) -> -sin(x)
  194. """
  195. return -sin(root[0][0])
  196. MESSAGES[cosinus] = _('Apply standard derivative d/dx cos(x) = -sin(x).')
  197. def tangens(root, args):
  198. """
  199. der(tan(x), x) -> der(sin(x) / cos(x), x)
  200. """
  201. f = root[0][0]
  202. x = root[1] if len(root) > 1 else None
  203. return der(sin(f) / cos(f), x)
  204. MESSAGES[tangens] = \
  205. _('Convert the tanges to a division and apply the product rule.')