utils.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from ..node import ExpressionNode as N, ExpressionLeaf as L, OP_MUL, OP_DIV, \
  2. OP_ADD, OP_POW, OP_SQRT
  3. def greatest_common_divisor(a, b):
  4. """
  5. Return greatest common divisor of a and b using Euclid's Algorithm.
  6. """
  7. while b:
  8. a, b = b, a % b
  9. return a
  10. def lcm(a, b):
  11. """
  12. Return least common multiple of a and b.
  13. """
  14. return a * b // greatest_common_divisor(a, b)
  15. def least_common_multiple(*args):
  16. """
  17. Return lcm of args.
  18. """
  19. return reduce(lcm, args)
  20. def is_fraction(node, nominator, denominator):
  21. """
  22. Check if a node represents the fraction of a given nominator and
  23. denominator.
  24. >>> a, l1, l2 = L('a'), L(1), L(2)
  25. >>> is_fraction(a / l2, a, 2)
  26. True
  27. >>> is_fraction(l1 / l2 * a, a, 2)
  28. True
  29. >>> is_fraction(l2 / l1 * a, a, 2)
  30. False
  31. """
  32. if node.is_op(OP_DIV):
  33. nom, denom = node
  34. return nom == nominator and denom == denominator
  35. if node.is_op(OP_MUL):
  36. # 1 / denominator * nominator
  37. # nominator * 1 / denominator
  38. left, right = node
  39. fraction = L(1) / denominator
  40. return (left == nominator and right == fraction) \
  41. or (right == nominator and left == fraction)
  42. return False
  43. def partition(callback, iterable):
  44. """
  45. Partition an iterable into two parts using a callback that returns a
  46. boolean.
  47. Example:
  48. >>> partition(lambda x: x & 1, range(6))
  49. ([1, 3, 5], [0, 2, 4])
  50. """
  51. a, b = [], []
  52. for item in iterable:
  53. (a if callback(item) else b).append(item)
  54. return a, b
  55. def find_variables(node):
  56. """
  57. Find all variables in a node.
  58. """
  59. if node.is_variable():
  60. return set([node.value])
  61. if not node.is_leaf:
  62. return reduce(lambda a, b: a | b, map(find_variables, node))
  63. return set()
  64. def first_sorted_variable(variables):
  65. """
  66. In a set of variables, find the main variable to be used in a derivation or
  67. integral. The prioritized order is x, y, z, a, b, c, d, ...
  68. """
  69. for x in 'xyz':
  70. if x in variables:
  71. return x
  72. return sorted(variables)[0]
  73. def find_variable(exp):
  74. """
  75. Find the main (e.g. first prioritized) variable in an expression and return
  76. it as an ExpressionNode object. If no variable is present, return 'x' by
  77. default.
  78. """
  79. variables = find_variables(exp)
  80. if not len(variables):
  81. variables.add('x')
  82. return L(first_sorted_variable(variables))
  83. def substitute(f, x, replacement):
  84. """
  85. Replace all occurences of variable x in function f with the specified
  86. replacement.
  87. """
  88. if f == x:
  89. return replacement.clone()
  90. if f.is_leaf:
  91. return f
  92. children = map(lambda c: substitute(c, x, replacement), f)
  93. return N(f.op, *children, negated=f.negated)
  94. def divides(m, n):
  95. """
  96. Check if m | n (m divides n).
  97. """
  98. return not divmod(n, m)[1]
  99. def dividers(n):
  100. """
  101. Find all integers that divide n, except for 1.
  102. """
  103. def m_dividers(m):
  104. result, rest = divmod(n, m)
  105. if not rest:
  106. return [m, result] if m != result else [m]
  107. below_sqrt = filter(None, map(m_dividers, xrange(2, int(n ** .5) + 1)))
  108. div = reduce(lambda a, b: a + b, below_sqrt, [])
  109. div.sort()
  110. return div
  111. def is_prime(n):
  112. """
  113. Check if n is a prime.
  114. """
  115. if n == 2:
  116. return True
  117. if n < 2 or not n & 1:
  118. return False
  119. for i in xrange(3, int(n ** .5) + 1, 2):
  120. if not divmod(n, i)[1]:
  121. return False
  122. return True
  123. def prime_dividers(n):
  124. """
  125. Find all primes that divide n.
  126. """
  127. return filter(is_prime, dividers(n))
  128. def is_numeric_node(node):
  129. """
  130. Check if a node is numeric.
  131. """
  132. return node.is_numeric()
  133. def evals_to_numeric(node):
  134. """
  135. Check if a node will eventually evaluate to a numeric value, by checking if
  136. all leaves are numeric and there are only operators that can be
  137. considerered a constant or will evaluate to one (+, *, /, ^, sqrt).
  138. """
  139. if node.is_leaf:
  140. return node.is_numeric()
  141. return node.op in (OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_SQRT) \
  142. and all(map(evals_to_numeric, node))
  143. def iter_pairs(list_iterable):
  144. """
  145. Iterate over a list iterable in left-right pairs.
  146. """
  147. if len(list_iterable) < 2:
  148. raise StopIteration
  149. for i, left in enumerate(list_iterable[:-1]):
  150. yield left, list_iterable[i + 1]
  151. def range_except(start, end, exception):
  152. return range(start, exception) + range(exception + 1, end)