derivatives.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. from .utils import find_variables, first_sorted_variable
  2. from ..node import ExpressionLeaf as L, Scope, OP_DER, OP_MUL, OP_LOG, \
  3. OP_SIN, OP_COS, OP_TAN, OP_ADD, OP_DIV, E, sin, cos, der, ln
  4. from ..possibilities import Possibility as P, MESSAGES
  5. from ..translate import _
  6. def second_arg(node):
  7. """
  8. Get the second child of a node if it exists, None otherwise.
  9. """
  10. return node[1] if len(node) > 1 else None
  11. def get_derivation_variable(node, variables=None):
  12. """
  13. Find the variable to derive over.
  14. >>> print get_derivation_variable(der(L('x')))
  15. 'x'
  16. """
  17. if len(node) > 1:
  18. assert node[1].is_identifier()
  19. return node[1].value
  20. if not variables:
  21. variables = find_variables(node)
  22. if not len(variables):
  23. return None
  24. return first_sorted_variable(variables)
  25. def chain_rule(root, args):
  26. """
  27. Apply the chain rule:
  28. [f(g(x)]' -> f'(g(x)) * g'(x)
  29. f'(g(x)) is not expressable in the current syntax, so calculate it directly
  30. using the application function in the arguments. g'(x) is simply expressed
  31. as der(g(x), x).
  32. """
  33. g, f_deriv, f_deriv_args = args
  34. x = root[1] if len(root) > 1 else None
  35. return f_deriv(root, f_deriv_args) * der(g, x)
  36. def match_zero_derivative(node):
  37. """
  38. der(x, y) -> 0
  39. der(n) -> 0
  40. """
  41. assert node.is_op(OP_DER)
  42. variables = find_variables(node[0])
  43. var = get_derivation_variable(node, variables)
  44. if not var or var not in variables:
  45. return [P(node, zero_derivative)]
  46. return []
  47. def zero_derivative(root, args):
  48. """
  49. der(x, y) -> 0
  50. der(n) -> 0
  51. """
  52. return L(0)
  53. MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
  54. def match_one_derivative(node):
  55. """
  56. der(x) -> 1 # Implicit x
  57. der(x, x) -> 1 # Explicit x
  58. """
  59. assert node.is_op(OP_DER)
  60. var = get_derivation_variable(node)
  61. if var and node[0] == L(var):
  62. return [P(node, one_derivative)]
  63. return []
  64. def one_derivative(root, args):
  65. """
  66. der(x) -> 1
  67. der(x, x) -> 1
  68. """
  69. return L(1)
  70. MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
  71. def match_const_deriv_multiplication(node):
  72. """
  73. der(c * f(x), x) -> c * der(f(x), x)
  74. """
  75. assert node.is_op(OP_DER)
  76. p = []
  77. if node[0].is_op(OP_MUL):
  78. x = L(get_derivation_variable(node))
  79. scope = Scope(node[0])
  80. for n in scope:
  81. if not n.contains(x):
  82. p.append(P(node, const_deriv_multiplication, (scope, n, x)))
  83. return p
  84. def const_deriv_multiplication(root, args):
  85. """
  86. der(c * f(x), x) -> c * der(f(x), x)
  87. """
  88. scope, c, x = args
  89. scope.remove(c)
  90. return c * der(scope.as_nary_node(), x)
  91. MESSAGES[const_deriv_multiplication] = \
  92. _('Bring multiplication with {2} in derivative {0} to the outside.')
  93. def match_variable_power(node):
  94. """
  95. der(x ^ n) -> n * x ^ (n - 1)
  96. der(x ^ n, x) -> n * x ^ (n - 1)
  97. der(f(x) ^ n) -> n * f(x) ^ (n - 1) * der(f(x)) # Chain rule
  98. der(f(x) ^ g(x), x) -> der(e ^ ln(f(x) ^ g(x)), x)
  99. """
  100. assert node.is_op(OP_DER)
  101. if not node[0].is_power():
  102. return []
  103. root, exponent = node[0]
  104. rvars = find_variables(root)
  105. evars = find_variables(exponent)
  106. x = get_derivation_variable(node, rvars | evars)
  107. if x in rvars and x in evars:
  108. return [P(node, power_rule)]
  109. if x in rvars:
  110. if root.is_variable():
  111. return [P(node, variable_root)]
  112. return [P(node, chain_rule, (root, variable_root, ()))]
  113. if x in evars:
  114. if exponent.is_variable():
  115. return [P(node, variable_exponent)]
  116. return [P(node, chain_rule, (exponent, variable_exponent, ()))]
  117. return []
  118. def power_rule(root, args):
  119. """
  120. [f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]'
  121. """
  122. x = second_arg(root)
  123. return der(L(E) ** ln(root[0]), x)
  124. MESSAGES[power_rule] = \
  125. _('Write {0} as a logarithm to be able to separate root and exponent.')
  126. def variable_root(root, args):
  127. """
  128. der(x ^ n, x) -> n * x ^ (n - 1)
  129. """
  130. x, n = root[0]
  131. return n * x ** (n - 1)
  132. MESSAGES[variable_root] = \
  133. _('Apply standard derivative d/dx x ^ n = n * x ^ (n - 1) on {0}.')
  134. def variable_exponent(root, args):
  135. """
  136. der(g ^ x, x) -> g ^ x * ln(g)
  137. Note that (in combination with logarithmic/constant rules):
  138. der(e ^ x) -> e ^ x * ln(e) -> e ^ x * 1 -> e ^ x
  139. """
  140. g, x = root[0]
  141. return g ** x * ln(g)
  142. MESSAGES[variable_exponent] = \
  143. _('Apply standard derivative d/dx g ^ x = g ^ x * ln g.')
  144. def match_logarithmic(node):
  145. """
  146. der(log(x, g), x) -> 1 / (x * ln(g))
  147. der(log(f(x), g), x) -> 1 / (f(x) * ln(g)) * der(f(x), x)
  148. """
  149. assert node.is_op(OP_DER)
  150. x = get_derivation_variable(node)
  151. if x and node[0].is_op(OP_LOG):
  152. f = node[0][0]
  153. x = L(x)
  154. if f == x:
  155. return [P(node, logarithmic, ())]
  156. if f.contains(x):
  157. return [P(node, chain_rule, (f, logarithmic, ()))]
  158. return []
  159. def logarithmic(root, args):
  160. """
  161. der(log(x, g), x) -> 1 / (x * ln(g))
  162. """
  163. x, g = root[0]
  164. return L(1) / (x * ln(g))
  165. MESSAGES[logarithmic] = \
  166. _('Apply standard derivative d/dx log(x, g) = 1 / (x * ln(g)).')
  167. def match_goniometric(node):
  168. """
  169. der(sin(x), x) -> cos(x)
  170. der(sin(f(x)), x) -> cos(f(x)) * der(f(x), x)
  171. der(cos(x), x) -> -sin(x)
  172. der(cos(f(x)), x) -> -sin(f(x)) * der(f(x), x)
  173. der(tan(x), x) -> der(sin(x) / cos(x), x)
  174. """
  175. assert node.is_op(OP_DER)
  176. x = get_derivation_variable(node)
  177. if x and not node[0].is_leaf:
  178. op = node[0].op
  179. if op in (OP_SIN, OP_COS):
  180. f = node[0][0]
  181. x = L(x)
  182. handler = sinus if op == OP_SIN else cosinus
  183. if f == x:
  184. return [P(node, handler)]
  185. if f.contains(x):
  186. return [P(node, chain_rule, (f, handler, ()))]
  187. if op == OP_TAN:
  188. return [P(node, tangens)]
  189. return []
  190. def sinus(root, args):
  191. """
  192. der(sin(x), x) -> cos(x)
  193. """
  194. return cos(root[0][0])
  195. MESSAGES[sinus] = _('Apply standard derivative d/dx sin(x) = cos(x).')
  196. def cosinus(root, args):
  197. """
  198. der(cos(x), x) -> -sin(x)
  199. """
  200. return -sin(root[0][0])
  201. MESSAGES[cosinus] = _('Apply standard derivative d/dx cos(x) = -sin(x).')
  202. def tangens(root, args):
  203. """
  204. der(tan(x), x) -> der(sin(x) / cos(x), x)
  205. """
  206. x = root[0][0]
  207. return der(sin(x) / cos(x), second_arg(root))
  208. MESSAGES[tangens] = \
  209. _('Convert the tanges to a division and apply the product rule.')
  210. def match_sum_product_rule(node):
  211. """
  212. [f(x) + g(x)]' -> f'(x) + g'(x)
  213. [f(x) * g(x)]' -> f'(x) * g(x) + f(x) * g'(x)
  214. """
  215. assert node.is_op(OP_DER)
  216. x = get_derivation_variable(node)
  217. if not x or node[0].is_leaf or node[0].op not in (OP_ADD, OP_MUL):
  218. return []
  219. scope = Scope(node[0])
  220. x = L(x)
  221. functions = [n for n in scope if n.contains(x)]
  222. if len(functions) < 2:
  223. return []
  224. p = []
  225. handler = sum_rule if node[0].op == OP_ADD else product_rule
  226. for f in functions:
  227. p.append(P(node, handler, (scope, f)))
  228. return p
  229. def sum_rule(root, args):
  230. """
  231. [f(x) + g(x)]' -> f'(x) + g'(x)
  232. """
  233. scope, f = args
  234. x = second_arg(root)
  235. scope.remove(f)
  236. return der(f, x) + der(scope.as_nary_node(), x)
  237. MESSAGES[sum_rule] = _('Apply the sum rule to {0}.')
  238. def product_rule(root, args):
  239. """
  240. [f(x) * g(x)]' -> f'(x) * g(x) + f(x) * g'(x)
  241. Note that implicitely:
  242. [f(x) * g(x) * h(x)]' -> f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]'
  243. """
  244. scope, f = args
  245. x = second_arg(root)
  246. scope.remove(f)
  247. gh = scope.as_nary_node()
  248. return der(f, x) * gh + f * der(gh, x)
  249. MESSAGES[product_rule] = _('Apply the product rule to {0}.')
  250. def match_quotient_rule(node):
  251. """
  252. [f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
  253. """
  254. assert node.is_op(OP_DER)
  255. x = get_derivation_variable(node)
  256. if not x or not node[0].is_op(OP_DIV):
  257. return []
  258. f, g = node[0]
  259. x = L(x)
  260. if f.contains(x) and g.contains(x):
  261. return [P(node, quotient_rule)]
  262. return []
  263. def quotient_rule(root, args):
  264. """
  265. [f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
  266. """
  267. f, g = root[0]
  268. x = second_arg(root)
  269. return (der(f, x) * g - f * der(g, x)) / g ** 2
  270. MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.')