derivatives.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. from .utils import find_variables
  2. from .logarithmic import ln
  3. from .goniometry import sin, cos
  4. from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, OP_DER, \
  5. OP_MUL, OP_LOG, OP_SIN, OP_COS, OP_TAN, OP_ADD, OP_DIV
  6. from ..possibilities import Possibility as P, MESSAGES
  7. from ..translate import _
  8. def der(f, x=None):
  9. return N(OP_DER, f, x) if x else N(OP_DER, f)
  10. def second_arg(node):
  11. """
  12. Get the second child of a node if it exists, None otherwise.
  13. """
  14. return node[1] if len(node) > 1 else None
  15. def get_derivation_variable(node, variables=None):
  16. """
  17. Find the variable to derive over.
  18. >>> print get_derivation_variable(der(L('x')))
  19. 'x'
  20. """
  21. if len(node) > 1:
  22. assert node[1].is_identifier()
  23. return node[1].value
  24. if not variables:
  25. variables = find_variables(node)
  26. if len(variables) > 1:
  27. # FIXME: Use first variable, sorted alphabetically?
  28. #return sorted(variables)[0]
  29. raise ValueError('More than 1 variable in implicit derivative: '
  30. + ', '.join(variables))
  31. if not len(variables):
  32. return None
  33. return list(variables)[0]
  34. def chain_rule(root, args):
  35. """
  36. Apply the chain rule:
  37. [f(g(x)]' -> f'(g(x)) * g'(x)
  38. f'(g(x)) is not expressable in the current syntax, so calculate it directly
  39. using the application function in the arguments. g'(x) is simply expressed
  40. as der(g(x), x).
  41. """
  42. g, f_deriv, f_deriv_args = args
  43. x = root[1] if len(root) > 1 else None
  44. return f_deriv(root, f_deriv_args) * der(g, x)
  45. def match_zero_derivative(node):
  46. """
  47. der(x, y) -> 0
  48. der(n) -> 0
  49. """
  50. assert node.is_op(OP_DER)
  51. variables = find_variables(node[0])
  52. var = get_derivation_variable(node, variables)
  53. if not var or var not in variables:
  54. return [P(node, zero_derivative)]
  55. return []
  56. def zero_derivative(root, args):
  57. """
  58. der(x, y) -> 0
  59. der(n) -> 0
  60. """
  61. return L(0)
  62. MESSAGES[zero_derivative] = _('Constant {0[0]} has derivative 0.')
  63. def match_one_derivative(node):
  64. """
  65. der(x) -> 1 # Implicit x
  66. der(x, x) -> 1 # Explicit x
  67. """
  68. assert node.is_op(OP_DER)
  69. var = get_derivation_variable(node)
  70. if var and node[0] == L(var):
  71. return [P(node, one_derivative)]
  72. return []
  73. def one_derivative(root, args):
  74. """
  75. der(x) -> 1
  76. der(x, x) -> 1
  77. """
  78. return L(1)
  79. MESSAGES[one_derivative] = _('Variable {0[0]} has derivative 1.')
  80. def match_const_deriv_multiplication(node):
  81. """
  82. der(c * f(x), x) -> c * der(f(x), x)
  83. """
  84. assert node.is_op(OP_DER)
  85. p = []
  86. if node[0].is_op(OP_MUL):
  87. x = L(get_derivation_variable(node))
  88. scope = Scope(node[0])
  89. for n in scope:
  90. if not n.contains(x):
  91. p.append(P(node, const_deriv_multiplication, (scope, n, x)))
  92. return p
  93. def const_deriv_multiplication(root, args):
  94. """
  95. der(c * f(x), x) -> c * der(f(x), x)
  96. """
  97. scope, c, x = args
  98. scope.remove(c)
  99. return c * der(scope.as_nary_node(), x)
  100. MESSAGES[const_deriv_multiplication] = \
  101. _('Bring multiplication with {2} in derivative {0} to the outside.')
  102. def match_variable_power(node):
  103. """
  104. der(x ^ n) -> n * x ^ (n - 1)
  105. der(x ^ n, x) -> n * x ^ (n - 1)
  106. der(f(x) ^ n) -> n * f(x) ^ (n - 1) * der(f(x)) # Chain rule
  107. """
  108. assert node.is_op(OP_DER)
  109. if not node[0].is_power():
  110. return []
  111. root, exponent = node[0]
  112. rvars = find_variables(root)
  113. evars = find_variables(exponent)
  114. x = get_derivation_variable(node, rvars | evars)
  115. if x in rvars and x not in evars:
  116. if root.is_variable():
  117. return [P(node, variable_root)]
  118. return [P(node, chain_rule, (root, variable_root, ()))]
  119. elif not x in rvars and x in evars:
  120. if exponent.is_variable():
  121. return [P(node, variable_exponent)]
  122. return [P(node, chain_rule, (exponent, variable_exponent, ()))]
  123. return []
  124. def variable_root(root, args):
  125. """
  126. der(x ^ n, x) -> n * x ^ (n - 1)
  127. """
  128. x, n = root[0]
  129. return n * x ** (n - 1)
  130. MESSAGES[variable_root] = \
  131. _('Apply standard derivative d/dx x ^ n = n * x ^ (n - 1) on {0}.')
  132. def variable_exponent(root, args):
  133. """
  134. der(g ^ x, x) -> g ^ x * ln(g)
  135. Note that (in combination with logarithmic/constant rules):
  136. der(e ^ x) -> e ^ x * ln(e) -> e ^ x * 1 -> e ^ x
  137. """
  138. # TODO: Put above example 'der(e ^ x)' in unit test
  139. g, x = root[0]
  140. return g ** x * ln(g)
  141. MESSAGES[variable_exponent] = \
  142. _('Apply standard derivative d/dx g ^ x = g ^ x * ln g.')
  143. def match_logarithmic(node):
  144. """
  145. der(log(x, g), x) -> 1 / (x * ln(g))
  146. der(log(f(x), g), x) -> 1 / (f(x) * ln(g)) * der(f(x), x)
  147. """
  148. assert node.is_op(OP_DER)
  149. x = get_derivation_variable(node)
  150. if x and node[0].is_op(OP_LOG):
  151. f = node[0][0]
  152. x = L(x)
  153. if f == x:
  154. return [P(node, logarithmic, ())]
  155. if f.contains(x):
  156. return [P(node, chain_rule, (f, logarithmic, ()))]
  157. return []
  158. def logarithmic(root, args):
  159. """
  160. der(log(x, g), x) -> 1 / (x * ln(g))
  161. """
  162. x, g = root[0]
  163. return L(1) / (x * ln(g))
  164. MESSAGES[logarithmic] = \
  165. _('Apply standard derivative d/dx log(x, g) = 1 / (x * ln(g)).')
  166. def match_goniometric(node):
  167. """
  168. der(sin(x), x) -> cos(x)
  169. der(sin(f(x)), x) -> cos(f(x)) * der(f(x), x)
  170. der(cos(x), x) -> -sin(x)
  171. der(cos(f(x)), x) -> -sin(f(x)) * der(f(x), x)
  172. der(tan(x), x) -> der(sin(x) / cos(x), x)
  173. """
  174. assert node.is_op(OP_DER)
  175. x = get_derivation_variable(node)
  176. if x and not node[0].is_leaf:
  177. op = node[0].op
  178. if op in (OP_SIN, OP_COS):
  179. f = node[0][0]
  180. x = L(x)
  181. handler = sinus if op == OP_SIN else cosinus
  182. if f == x:
  183. return [P(node, handler)]
  184. if f.contains(x):
  185. return [P(node, chain_rule, (f, handler, ()))]
  186. if op == OP_TAN:
  187. return [P(node, tangens)]
  188. return []
  189. def sinus(root, args):
  190. """
  191. der(sin(x), x) -> cos(x)
  192. """
  193. return cos(root[0][0])
  194. MESSAGES[sinus] = _('Apply standard derivative d/dx sin(x) = cos(x).')
  195. def cosinus(root, args):
  196. """
  197. der(cos(x), x) -> -sin(x)
  198. """
  199. return -sin(root[0][0])
  200. MESSAGES[cosinus] = _('Apply standard derivative d/dx cos(x) = -sin(x).')
  201. def tangens(root, args):
  202. """
  203. der(tan(x), x) -> der(sin(x) / cos(x), x)
  204. """
  205. x = root[0][0]
  206. return der(sin(x) / cos(x), second_arg(root))
  207. MESSAGES[tangens] = \
  208. _('Convert the tanges to a division and apply the product rule.')
  209. def match_sum_product_rule(node):
  210. """
  211. [f(x) + g(x)]' -> f'(x) + g'(x)
  212. [f(x) * g(x)]' -> f'(x) * g(x) + f(x) * g'(x)
  213. """
  214. assert node.is_op(OP_DER)
  215. x = get_derivation_variable(node)
  216. if not x or node[0].is_leaf or node[0].op not in (OP_ADD, OP_MUL):
  217. return []
  218. scope = Scope(node[0])
  219. x = L(x)
  220. functions = [n for n in scope if n.contains(x)]
  221. if len(functions) < 2:
  222. return []
  223. p = []
  224. handler = sum_rule if node[0].op == OP_ADD else product_rule
  225. for f in functions:
  226. p.append(P(node, handler, (scope, f)))
  227. return p
  228. def sum_rule(root, args):
  229. """
  230. [f(x) + g(x)]' -> f'(x) + g'(x)
  231. """
  232. scope, f = args
  233. x = second_arg(root)
  234. scope.remove(f)
  235. return der(f, x) + der(scope.as_nary_node(), x)
  236. MESSAGES[sum_rule] = _('Apply the sum rule to {0}.')
  237. def product_rule(root, args):
  238. """
  239. [f(x) * g(x)]' -> f'(x) * g(x) + f(x) * g'(x)
  240. Note that implicitely:
  241. [f(x) * g(x) * h(x)]' -> f'(x) * (g(x) * h(x)) + f(x) * [g(x) * h(x)]'
  242. """
  243. scope, f = args
  244. x = second_arg(root)
  245. scope.remove(f)
  246. gh = scope.as_nary_node()
  247. return der(f, x) * gh + f * der(gh, x)
  248. MESSAGES[product_rule] = _('Apply the product rule to {0}.')
  249. def match_quotient_rule(node):
  250. """
  251. [f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
  252. """
  253. assert node.is_op(OP_DER)
  254. x = get_derivation_variable(node)
  255. if not x or not node[0].is_op(OP_DIV):
  256. return []
  257. f, g = node[0]
  258. x = L(x)
  259. if f.contains(x) and g.contains(x):
  260. return [P(node, quotient_rule)]
  261. return []
  262. def quotient_rule(root, args):
  263. """
  264. [f(x) / g(x)]' -> (f'(x) * g(x) - f(x) * g'(x)) / g(x) ^ 2
  265. """
  266. f, g = root[0]
  267. x = second_arg(root)
  268. return (der(f, x) * g - f * der(g, x)) / g ** 2
  269. MESSAGES[quotient_rule] = _('Apply the quotient rule to {0}.')