powers.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from itertools import combinations
  2. from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
  3. OP_MUL, OP_DIV, OP_POW, OP_ADD
  4. from ..possibilities import Possibility as P, MESSAGES
  5. from ..translate import _
  6. def match_add_exponents(node):
  7. """
  8. a^p * a^q -> a^(p + q)
  9. a * a^q -> a^(1 + q)
  10. a^p * a -> a^(p + 1)
  11. a * a -> a^(1 + 1)
  12. """
  13. assert node.is_op(OP_MUL)
  14. p = []
  15. powers = {}
  16. scope = Scope(node)
  17. for n in scope:
  18. if n.is_identifier():
  19. s = n
  20. exponent = L(1)
  21. elif n.is_op(OP_POW):
  22. # Order powers by their roots, e.g. a^p and a^q are put in the same
  23. # list because of the mutual 'a'
  24. s, exponent = n
  25. else: # pragma: nocover
  26. continue
  27. s_str = str(s)
  28. if s_str in powers:
  29. powers[s_str].append((n, exponent, s))
  30. else:
  31. powers[s_str] = [(n, exponent, s)]
  32. for root, occurrences in powers.iteritems():
  33. # If a root has multiple occurences, their exponents can be added to
  34. # create a single power with that root
  35. if len(occurrences) > 1:
  36. for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
  37. p.append(P(node, add_exponents, (scope, n0, n1, a0, e1, e2)))
  38. return p
  39. def add_exponents(root, args):
  40. """
  41. a^p * a^q -> a^(p + q)
  42. """
  43. scope, n0, n1, a, p, q = args
  44. # Replace the left node with the new expression
  45. scope.replace(n0, a ** (p + q))
  46. # Remove the right node
  47. scope.remove(n1)
  48. return scope.as_nary_node()
  49. MESSAGES[add_exponents] = _('Add the exponents of {2} and {3}.')
  50. def match_subtract_exponents(node):
  51. """
  52. a^p / a^q -> a^(p - q)
  53. a^p / a -> a^(p - 1)
  54. a / a^q -> a^(1 - q)
  55. """
  56. assert node.is_op(OP_DIV)
  57. left, right = node
  58. left_pow, right_pow = left.is_op(OP_POW), right.is_op(OP_POW)
  59. if left_pow and right_pow and left[0] == right[0]:
  60. # A power is divided by a power with the same root
  61. return [P(node, subtract_exponents, tuple(left) + (right[1],))]
  62. if left_pow and left[0] == right:
  63. # A power is divided by a its root
  64. return [P(node, subtract_exponents, tuple(left) + (1,))]
  65. if right_pow and left == right[0]:
  66. # An identifier is divided by a power of itself
  67. return [P(node, subtract_exponents, (left, 1, right[1]))]
  68. return []
  69. def subtract_exponents(root, args):
  70. """
  71. a^p / a^q -> a^(p - q)
  72. """
  73. a, p, q = args
  74. return a ** (p - q)
  75. MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
  76. def match_multiply_exponents(node):
  77. """
  78. (a^p)^q -> a^(pq)
  79. """
  80. assert node.is_op(OP_POW)
  81. left, right = node
  82. if left.is_op(OP_POW):
  83. return [P(node, multiply_exponents, tuple(left) + (right,))]
  84. return []
  85. def multiply_exponents(root, args):
  86. """
  87. (a^p)^q -> a^(pq)
  88. """
  89. a, p, q = args
  90. return a ** (p * q)
  91. MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
  92. def match_duplicate_exponent(node):
  93. """
  94. (ab)^p -> a^p * b^p
  95. """
  96. assert node.is_op(OP_POW)
  97. left, right = node
  98. if left.is_op(OP_MUL):
  99. return [P(node, duplicate_exponent, (list(Scope(left)), right))]
  100. return []
  101. def duplicate_exponent(root, args):
  102. """
  103. (ab)^p -> a^p * b^p
  104. (abc)^p -> a^p * b^p * c^p
  105. """
  106. ab, p = args
  107. result = ab[0] ** p
  108. for b in ab[1:]:
  109. result *= b ** p
  110. return result
  111. MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
  112. def match_remove_negative_exponent(node):
  113. """
  114. a ^ -p -> 1 / a ^ p
  115. """
  116. assert node.is_op(OP_POW)
  117. a, p = node
  118. if p.negated:
  119. return [P(node, remove_negative_exponent, (a, p))]
  120. return []
  121. def remove_negative_exponent(root, args):
  122. """
  123. a^-p -> 1 / a^p
  124. """
  125. a, p = args
  126. return L(1) / a ** p.reduce_negation()
  127. MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
  128. def match_exponent_to_root(node):
  129. """
  130. a^(1 / m) -> sqrt(a, m)
  131. a^(n / m) -> sqrt(a^n, m)
  132. """
  133. assert node.is_op(OP_POW)
  134. left, right = node
  135. if right.is_op(OP_DIV):
  136. return [P(node, exponent_to_root, (left,) + tuple(right))]
  137. return []
  138. def exponent_to_root(root, args):
  139. """
  140. a^(1 / m) -> sqrt(a, m)
  141. a^(n / m) -> sqrt(a^n, m)
  142. """
  143. a, n, m = args
  144. return N('sqrt', a if n == 1 else a ** n, m)
  145. def match_extend_exponent(node):
  146. """
  147. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  148. """
  149. assert node.is_op(OP_POW)
  150. left, right = node
  151. if right.is_numeric():
  152. for n in Scope(node):
  153. if n.is_op(OP_ADD):
  154. return [P(node, extend_exponent, (left, right))]
  155. return []
  156. def extend_exponent(root, args):
  157. """
  158. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  159. """
  160. left, right = args
  161. if right.value > 2:
  162. return left * left ** L(right.value - 1)
  163. return left * left
  164. def match_constant_exponent(node):
  165. """
  166. (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
  167. """
  168. assert node.is_op(OP_POW)
  169. exponent = node[1]
  170. if exponent == 0:
  171. return [P(node, remove_power_of_zero, ())]
  172. if exponent == 1:
  173. return [P(node, remove_power_of_one, ())]
  174. return []
  175. def remove_power_of_zero(root, args):
  176. """
  177. a ^ 0 -> 1
  178. """
  179. return L(1)
  180. MESSAGES[remove_power_of_zero] = _('Power of zero {0} rewrites to 1.')
  181. def remove_power_of_one(root, args):
  182. """
  183. a ^ 1 -> a
  184. """
  185. return root[0]
  186. MESSAGES[remove_power_of_one] = _('Remove the power of one in {0}.')