powers.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from itertools import combinations
  2. from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
  3. OP_MUL, OP_DIV, OP_POW, OP_ADD, OP_SQRT, negate
  4. from ..possibilities import Possibility as P, MESSAGES
  5. from ..translate import _
  6. def match_add_exponents(node):
  7. """
  8. a^p * a^q -> a^(p + q)
  9. a * a^q -> a^(1 + q)
  10. a^p * a -> a^(p + 1)
  11. a * a -> a^(1 + 1)
  12. -a * a^q -> -a^(1 + q)
  13. """
  14. assert node.is_op(OP_MUL)
  15. p = []
  16. powers = {}
  17. scope = Scope(node)
  18. for n in scope:
  19. # Order powers by their roots, e.g. a^p and a^q are put in the same
  20. # list because of the mutual 'a'
  21. if n.is_identifier():
  22. s = negate(n, 0, clone=True)
  23. exponent = L(1)
  24. elif n.is_op(OP_POW):
  25. s, exponent = n
  26. else: # pragma: nocover
  27. continue
  28. s_str = str(s)
  29. if s_str in powers:
  30. powers[s_str].append((n, exponent, s))
  31. else:
  32. powers[s_str] = [(n, exponent, s)]
  33. for root, occurrences in powers.iteritems():
  34. # If a root has multiple occurences, their exponents can be added to
  35. # create a single power with that root
  36. if len(occurrences) > 1:
  37. for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
  38. p.append(P(node, add_exponents,
  39. (scope.clone(), n0, n1, a0, e1, e2)))
  40. return p
  41. def add_exponents(root, args):
  42. """
  43. a^p * a^q -> a^(p + q)
  44. """
  45. scope, n0, n1, a, p, q = args
  46. # TODO: combine exponent negations
  47. # Replace the left node with the new expression
  48. scope.replace(n0, (a ** (p + q)).negate(n0.negated + n1.negated))
  49. # Remove the right node
  50. scope.remove(n1)
  51. return scope.as_nary_node()
  52. MESSAGES[add_exponents] = _('Add the exponents of {2} and {3}.')
  53. def match_subtract_exponents(node):
  54. """
  55. a^p / a^q -> a^(p - q)
  56. a^p / a -> a^(p - 1)
  57. a / a^q -> a^(1 - q)
  58. """
  59. assert node.is_op(OP_DIV)
  60. left, right = node
  61. left_pow, right_pow = left.is_op(OP_POW), right.is_op(OP_POW)
  62. if left_pow and right_pow and left[0] == right[0]:
  63. # A power is divided by a power with the same root
  64. return [P(node, subtract_exponents, tuple(left) + (right[1],))]
  65. if left_pow and left[0] == right:
  66. # A power is divided by a its root
  67. return [P(node, subtract_exponents, tuple(left) + (1,))]
  68. if right_pow and left == right[0]:
  69. # An identifier is divided by a power of itself
  70. return [P(node, subtract_exponents, (left, 1, right[1]))]
  71. return []
  72. def subtract_exponents(root, args):
  73. """
  74. a^p / a^q -> a^(p - q)
  75. """
  76. a, p, q = args
  77. return a ** (p - q)
  78. MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
  79. def match_multiply_exponents(node):
  80. """
  81. (a^p)^q -> a^(pq)
  82. """
  83. assert node.is_op(OP_POW)
  84. left, right = node
  85. if left.is_op(OP_POW):
  86. return [P(node, multiply_exponents, tuple(left) + (right,))]
  87. return []
  88. def multiply_exponents(root, args):
  89. """
  90. (a^p)^q -> a^(pq)
  91. """
  92. a, p, q = args
  93. return a ** (p * q)
  94. MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
  95. def match_duplicate_exponent(node):
  96. """
  97. (ab)^p -> a^p * b^p
  98. """
  99. assert node.is_op(OP_POW)
  100. root, exponent = node
  101. if root.is_op(OP_MUL):
  102. return [P(node, duplicate_exponent, (list(Scope(root)), exponent))]
  103. return []
  104. def duplicate_exponent(root, args):
  105. """
  106. (ab)^p -> a^p * b^p
  107. (abc)^p -> a^p * b^p * c^p
  108. """
  109. ab, p = args
  110. result = ab[0] ** p
  111. for b in ab[1:]:
  112. result *= b ** p
  113. return result
  114. MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
  115. def match_raised_fraction(node):
  116. """
  117. (a / b) ^ p -> a^p / b^p
  118. """
  119. assert node.is_op(OP_POW)
  120. root, exponent = node
  121. if root.is_op(OP_DIV):
  122. return [P(node, raised_fraction, (root, exponent))]
  123. return []
  124. def raised_fraction(root, args):
  125. """
  126. (a / b) ^ p -> a^p / b^p
  127. """
  128. (a, b), p = args
  129. return a ** p / b ** p
  130. MESSAGES[raised_fraction] = _('Apply the exponent {2} to the nominator and'
  131. ' denominator of fraction {1}.')
  132. def match_remove_negative_exponent(node):
  133. """
  134. a ^ -p -> 1 / a ^ p
  135. """
  136. assert node.is_op(OP_POW)
  137. a, p = node
  138. if p.negated:
  139. return [P(node, remove_negative_exponent, (a, p))]
  140. return []
  141. def remove_negative_exponent(root, args):
  142. """
  143. a^-p -> 1 / a^p
  144. """
  145. a, p = args
  146. return L(1) / a ** p.reduce_negation()
  147. MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
  148. def match_exponent_to_root(node):
  149. """
  150. a^(1 / m) -> sqrt(a, m)
  151. a^(n / m) -> sqrt(a^n, m)
  152. """
  153. assert node.is_op(OP_POW)
  154. left, right = node
  155. if right.is_op(OP_DIV):
  156. return [P(node, exponent_to_root, (left,) + tuple(right))]
  157. return []
  158. def exponent_to_root(root, args):
  159. """
  160. a^(1 / m) -> sqrt(a, m)
  161. a^(n / m) -> sqrt(a^n, m)
  162. """
  163. a, n, m = args
  164. return N(OP_SQRT, a if n == 1 else a ** n, m)
  165. def match_extend_exponent(node):
  166. """
  167. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  168. """
  169. assert node.is_op(OP_POW)
  170. left, right = node
  171. if right.is_numeric():
  172. for n in Scope(node):
  173. if n.is_op(OP_ADD):
  174. return [P(node, extend_exponent, (left, right))]
  175. return []
  176. def extend_exponent(root, args):
  177. """
  178. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  179. """
  180. left, right = args
  181. if right.value > 2:
  182. return left * left ** L(right.value - 1)
  183. return left * left
  184. def match_constant_exponent(node):
  185. """
  186. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  187. """
  188. assert node.is_op(OP_POW)
  189. exponent = node[1]
  190. if exponent == 0:
  191. return [P(node, remove_power_of_zero, ())]
  192. if exponent == 1:
  193. return [P(node, remove_power_of_one, ())]
  194. return []
  195. def remove_power_of_zero(root, args):
  196. """
  197. a ^ 0 -> 1
  198. """
  199. return L(1)
  200. MESSAGES[remove_power_of_zero] = _('Power of zero {0} rewrites to `1`.')
  201. def remove_power_of_one(root, args):
  202. """
  203. a ^ 1 -> a
  204. """
  205. return root[0]
  206. MESSAGES[remove_power_of_one] = _('Remove the power of one in {0}.')