sqrt.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # This file is part of TRS (http://math.kompiler.org)
  2. #
  3. # TRS is free software: you can redistribute it and/or modify it under the
  4. # terms of the GNU Affero General Public License as published by the Free
  5. # Software Foundation, either version 3 of the License, or (at your option) any
  6. # later version.
  7. #
  8. # TRS is distributed in the hope that it will be useful, but WITHOUT ANY
  9. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  10. # A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  11. # details.
  12. #
  13. # You should have received a copy of the GNU Affero General Public License
  14. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
import math
import numbers

from .utils import dividers, is_prime
from ..node import ExpressionLeaf as Leaf, Scope, OP_SQRT, OP_MUL, sqrt
from ..possibilities import Possibility as P, MESSAGES
from ..translate import _
  20. def is_eliminateable_sqrt(n):
  21. """
  22. Check if the square root of n can be evaluated so that the square root
  23. disappears (is eliminated).
  24. """
  25. if isinstance(n, int):
  26. return n > 3 and int(math.sqrt(n)) ** 2 == n
  27. if n.negated:
  28. return False
  29. if n.is_numeric():
  30. return is_eliminateable_sqrt(n.value)
  31. return n.is_power(2)
  32. def match_reduce_sqrt(node):
  33. """
  34. sqrt(a ^ 2) -> a
  35. sqrt(a) and eval(sqrt(a)) in Z -> eval(sqrt(a))
  36. sqrt(a) and a == b ^ 2 * c with a,b,c in Z -> sqrt(eval(b ^ 2) * c)
  37. sqrt(ab) -> sqrt(a)sqrt(b)
  38. """
  39. assert node.is_op(OP_SQRT)
  40. exp = node[0]
  41. if exp.negated:
  42. return []
  43. if exp.is_power(2):
  44. return [P(node, quadrant_sqrt)]
  45. if exp.is_numeric():
  46. reduced = int(math.sqrt(exp.value))
  47. if reduced ** 2 == exp.value:
  48. return [P(node, constant_sqrt, (reduced,))]
  49. div = filter(is_eliminateable_sqrt, dividers(exp.value))
  50. div.sort(lambda a, b: cmp(is_prime(b), is_prime(a)))
  51. return [P(node, split_dividers, (m, exp.value / m)) for m in div]
  52. if exp.is_op(OP_MUL):
  53. scope = Scope(exp)
  54. p = []
  55. for n in scope:
  56. if is_eliminateable_sqrt(n):
  57. p.append(P(node, extract_sqrt_mult_priority, (scope, n)))
  58. else:
  59. p.append(P(node, extract_sqrt_multiplicant, (scope, n)))
  60. return p
  61. return []
  62. def quadrant_sqrt(root, args):
  63. """
  64. sqrt(a ^ 2) -> a
  65. """
  66. return root[0][0].negate(root.negated)
  67. MESSAGES[quadrant_sqrt] = \
  68. _('The square root of a quadrant reduces to the raised root.')
  69. def constant_sqrt(root, args):
  70. """
  71. sqrt(a) and eval(sqrt(a)) in Z -> eval(sqrt(a))
  72. """
  73. return Leaf(args[0]).negate(root.negated)
  74. MESSAGES[constant_sqrt] = \
  75. _('The square root of {0[0]} is {1}.')
  76. def split_dividers(root, args):
  77. """
  78. sqrt(a) and b * c = a with a,b,c in Z -> sqrt(a * b)
  79. """
  80. b, c = args
  81. return sqrt(Leaf(b) * c)
  82. MESSAGES[split_dividers] = _('Write {0[0]} as {1} * {2} to so that {1} can ' \
  83. 'be brought outside of the square root.')
  84. def extract_sqrt_multiplicant(root, args):
  85. """
  86. sqrt(ab) -> sqrt(a)sqrt(b)
  87. """
  88. scope, a = args
  89. scope.remove(a)
  90. return (sqrt(a) * sqrt(scope.as_nary_node())).negate(root.negated)
  91. MESSAGES[extract_sqrt_multiplicant] = _('Extract {2} from {0}.')
  92. def extract_sqrt_mult_priority(root, args):
  93. """
  94. sqrt(ab) and sqrt(a) in Z -> sqrt(a)sqrt(b)
  95. """
  96. return extract_sqrt_multiplicant(root, args)
  97. MESSAGES[extract_sqrt_mult_priority] = MESSAGES[extract_sqrt_multiplicant]