numerics.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from itertools import combinations
  2. from ..node import ExpressionLeaf as Leaf, Scope, negate, OP_ADD, OP_DIV, \
  3. OP_MUL
  4. from ..possibilities import Possibility as P, MESSAGES
  5. from ..translate import _
  6. def match_add_numerics(node):
  7. """
  8. Combine two constants to a single constant in an n-ary addition.
  9. Example:
  10. 2 + 3 -> 5
  11. 2 + -3 -> -1
  12. -2 + 3 -> 1
  13. -2 + -3 -> -5
  14. """
  15. assert node.is_op(OP_ADD)
  16. p = []
  17. scope = Scope(node)
  18. numerics = filter(lambda n: n.is_numeric(), scope)
  19. for c0, c1 in combinations(numerics, 2):
  20. p.append(P(node, add_numerics, (scope, c0, c1)))
  21. return p
  22. def add_numerics(root, args):
  23. """
  24. 2 + 3 -> 5
  25. 2 + -3 -> -1
  26. -2 + 3 -> 1
  27. -2 + -3 -> -5
  28. """
  29. scope, c0, c1 = args
  30. value = c0.actual_value() + c1.actual_value()
  31. if value < 0:
  32. leaf = Leaf(-value).negate()
  33. else:
  34. leaf = Leaf(value)
  35. # Replace the left node with the new expression
  36. scope.replace(c0, Leaf(abs(value)).negate(int(value < 0)))
  37. # Remove the right node
  38. scope.remove(c1)
  39. return scope.as_nary_node()
  40. MESSAGES[add_numerics] = _('Add the constants {2} and {3}.')
  41. #def match_subtract_numerics(node):
  42. # """
  43. # 3 - 2 -> 1
  44. # 3.0 - 2 -> 1.0
  45. # 3 - 2.0 -> 1.0
  46. # 3.0 - 2.0 -> 1.0
  47. # """
  48. # # TODO: This should be handled by match_combine_polynomes
  49. # assert node.is_op(OP_MUL)
  50. def match_divide_numerics(node):
  51. """
  52. Combine two constants to a single constant in a division, if it does not
  53. lead to a decrease in precision.
  54. Example:
  55. 6 / 2 -> 3
  56. 3 / 2 -> 3 / 2 # 1.5 would mean a decrease in precision
  57. 3.0 / 2 -> 1.5
  58. 3 / 2.0 -> 1.5
  59. 3.0 / 2.0 -> 1.5
  60. 3 / 1.0 -> 3 # Exceptional case: division of integer by 1.0 keeps
  61. # integer precision
  62. """
  63. assert node.is_op(OP_DIV)
  64. n, d = node
  65. divide = False
  66. dv = d.value
  67. if n.is_int() and d.is_int():
  68. # 6 / 2 -> 3
  69. # 3 / 2 -> 3 / 2
  70. divide = not divmod(n.value, dv)[1]
  71. elif n.is_numeric() and d.is_numeric():
  72. if d == 1.0:
  73. # 3 / 1.0 -> 3
  74. dv = 1
  75. # 3.0 / 2 -> 1.5
  76. # 3 / 2.0 -> 1.5
  77. # 3.0 / 2.0 -> 1.5
  78. divide = True
  79. return [P(node, divide_numerics, (n.value, dv))] if divide else []
  80. def divide_numerics(root, args):
  81. """
  82. Combine two constants to a single constant in a division.
  83. Examples:
  84. 6 / 2 -> 3
  85. 3.0 / 2 -> 1.5
  86. 3 / 2.0 -> 1.5
  87. 3.0 / 2.0 -> 1.5
  88. 3 / 1.0 -> 3
  89. """
  90. n, d = args
  91. return Leaf(n / d)
  92. MESSAGES[divide_numerics] = _('Divide constant {1} by constant {2}.')
  93. def match_multiply_zero(node):
  94. """
  95. a * 0 -> 0
  96. 0 * a -> 0
  97. -0 * a -> -0
  98. 0 * -a -> -0
  99. -0 * -a -> 0
  100. """
  101. assert node.is_op(OP_MUL)
  102. left, right = node
  103. if (left.is_leaf and left.value == 0) \
  104. or (right.is_leaf and right.value == 0):
  105. return [P(node, multiply_zero, (left.negated + right.negated,))]
  106. return []
  107. def multiply_zero(root, args):
  108. """
  109. a * 0 -> 0
  110. 0 * a -> 0
  111. -0 * a -> -0
  112. 0 * -a -> -0
  113. -0 * -a -> 0
  114. """
  115. return negate(Leaf(0), args[0])
  116. MESSAGES[multiply_zero] = _('Multiplication with zero yields zero.')
  117. def match_multiply_one(node):
  118. """
  119. a * 1 -> a
  120. 1 * a -> a
  121. -1 * a -> -a
  122. 1 * -a -> -a
  123. -1 * -a -> a
  124. """
  125. assert node.is_op(OP_MUL)
  126. left, right = node
  127. if left.value == 1:
  128. return [P(node, multiply_one, (right, left))]
  129. if right.value == 1:
  130. return [P(node, multiply_one, (left, right))]
  131. return []
  132. def multiply_one(root, args):
  133. """
  134. a * 1 -> a
  135. 1 * a -> a
  136. -1 * a -> -a
  137. 1 * -a -> -a
  138. -1 * -a -> a
  139. """
  140. a, one = args
  141. return a.negate(one.negated + root.negated)
  142. MESSAGES[multiply_one] = _('Multiplication with one yields the multiplicant.')
  143. def match_multiply_numerics(node):
  144. """
  145. 3 * 2 -> 6
  146. 3.0 * 2 -> 6.0
  147. 3 * 2.0 -> 6.0
  148. 3.0 * 2.0 -> 6.0
  149. """
  150. assert node.is_op(OP_MUL)
  151. p = []
  152. scope = Scope(node)
  153. numerics = filter(lambda n: n.is_numeric(), scope)
  154. for c0, c1 in combinations(numerics, 2):
  155. p.append(P(node, multiply_numerics, (scope, c0, c1)))
  156. return p
  157. def multiply_numerics(root, args):
  158. """
  159. Combine two constants to a single constant in an n-ary multiplication.
  160. Example:
  161. 2 * 3 -> 6
  162. """
  163. scope, c0, c1 = args
  164. # Replace the left node with the new expression
  165. substitution = Leaf(c0.value * c1.value).negate(c0.negated + c1.negated)
  166. scope.replace(c0, substitution)
  167. # Remove the right node
  168. scope.remove(c1)
  169. return scope.as_nary_node()
  170. MESSAGES[multiply_numerics] = _('Multiply constant {2} with {3}.')