| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- from itertools import combinations
- from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
- OP_MUL, OP_DIV, OP_POW, OP_ADD
- from ..possibilities import Possibility as P, MESSAGES
- from ..translate import _
def match_add_exponents(node):
    """
    Find pairs of factors in a multiplication that share a root, so that
    their exponents can be added:

    a^p * a^q -> a^(p + q)
    a * a^q -> a^(1 + q)
    a^p * a -> a^(p + 1)
    a * a -> a^(1 + 1)

    :param node: An OP_MUL expression node.
    :return: A list of ``add_exponents`` possibilities, one for every pair of
             factors in the multiplication scope that have the same root.
             The args of each possibility are
             ``(scope, n0, n1, root, e0, e1)``.
    """
    assert node.is_op(OP_MUL)

    scope = Scope(node)

    # Group the factors by (the string form of) their root, e.g. a^p and a^q
    # are put in the same list because of the mutual 'a'.
    powers = {}

    for n in scope:
        if n.is_identifier():
            # A bare identifier 'a' is treated as 'a^1'
            s, exponent = n, L(1)
        elif n.is_op(OP_POW):
            s, exponent = n
        else:  # pragma: nocover
            continue

        powers.setdefault(str(s), []).append((n, exponent, s))

    p = []

    # '.values()' (instead of the Python 2-only '.iteritems()') works on both
    # Python 2 and 3 and yields the groups in the same order. A root with a
    # single occurrence produces no pairs, so no explicit length check is
    # needed.
    for occurrences in powers.values():
        for (n0, e1, a0), (n1, e2, _root1) in combinations(occurrences, 2):
            p.append(P(node, add_exponents, (scope, n0, n1, a0, e1, e2)))

    return p
def add_exponents(root, args):
    """
    Merge two powers with a common root into one power by summing their
    exponents:

    a^p * a^q -> a^(p + q)

    ``args`` is ``(scope, n0, n1, a, p, q)``: the multiplication scope, the
    two factors to merge, their shared root and their exponents. The merged
    power takes the place of ``n0`` in the scope and ``n1`` is dropped. The
    updated scope is returned as an n-ary node.
    """
    scope, left_factor, right_factor, base, p, q = args

    merged = base ** (p + q)
    scope.replace(left_factor, merged)
    scope.remove(right_factor)

    return scope.as_nary_node()


MESSAGES[add_exponents] = _('Add the exponents of {2} and {3}.')
def match_subtract_exponents(node):
    """
    Find a division of two powers, or of a power and its bare root, that
    share a root, so that their exponents can be subtracted:

    a^p / a^q -> a^(p - q)
    a^p / a -> a^(p - 1)
    a / a^q -> a^(1 - q)

    :param node: An OP_DIV expression node.
    :return: A single-item list with a ``subtract_exponents`` possibility
             whose args are ``(a, p, q)``, or an empty list.
    """
    assert node.is_op(OP_DIV)

    left, right = node
    left_pow, right_pow = left.is_op(OP_POW), right.is_op(OP_POW)

    if left_pow and right_pow and left[0] == right[0]:
        # A power is divided by a power with the same root
        return [P(node, subtract_exponents, tuple(left) + (right[1],))]

    # A bare root counts as that root raised to the power 1. A leaf is used
    # (as match_add_exponents does with L(1)) instead of a plain int, so the
    # exponent arithmetic in subtract_exponents ('p - 1', '1 - q') is done
    # on expression nodes. With an int on the left of '1 - q', it would rely
    # on the node class supporting reflected operators.
    if left_pow and left[0] == right:
        # A power is divided by its root
        return [P(node, subtract_exponents, tuple(left) + (L(1),))]

    if right_pow and left == right[0]:
        # An expression is divided by a power of itself
        return [P(node, subtract_exponents, (left, L(1), right[1]))]

    return []
def subtract_exponents(root, args):
    """
    Rewrite a quotient of powers with a common root as a single power:

    a^p / a^q -> a^(p - q)

    ``args`` is ``(a, p, q)``; ``root`` is not used.
    """
    base, numerator_exp, denominator_exp = args
    difference = numerator_exp - denominator_exp
    return base ** difference
# Explanation shown for subtract_exponents possibilities.
# NOTE: this text used to read 'Substract' (typo). Because it is wrapped in
# _(), presumably for translation, any catalog keyed on the old misspelled
# text needs updating as well.
MESSAGES[subtract_exponents] = _('Subtract the exponents {2} and {3}.')
def match_multiply_exponents(node):
    """
    Detect a power whose root is itself a power:

    (a^p)^q -> a^(pq)

    :return: ``[P(node, multiply_exponents, (a, p, q))]`` when the root of
             ``node`` is a power, otherwise an empty list.
    """
    assert node.is_op(OP_POW)

    base, outer_exponent = node

    if not base.is_op(OP_POW):
        return []

    inner_args = tuple(base) + (outer_exponent,)
    return [P(node, multiply_exponents, inner_args)]
def multiply_exponents(root, args):
    """
    Collapse a power of a power into one power with the product exponent:

    (a^p)^q -> a^(pq)

    ``args`` is ``(a, p, q)``; ``root`` is not used.
    """
    base, inner_exp, outer_exp = args
    product = inner_exp * outer_exp
    return base ** product
# Explanation shown for multiply_exponents possibilities.
# NOTE(review): the placeholders appear to be filled from (root, *args),
# which with args == (a, p, q) makes {2} and {3} the exponents p and q --
# confirm against the formatting code in ..possibilities.
MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
def match_duplicate_exponent(node):
    """
    Detect a product raised to a power, so that the exponent can be
    distributed over every factor:

    (ab)^p -> a^p * b^p

    :return: ``[P(node, duplicate_exponent, (factors, p))]`` where
             ``factors`` is the multiplication scope as a list, or an empty
             list when the root of ``node`` is not a multiplication.
    """
    assert node.is_op(OP_POW)

    base, exponent = node

    if not base.is_op(OP_MUL):
        return []

    factors = list(Scope(base))
    return [P(node, duplicate_exponent, (factors, exponent))]
def duplicate_exponent(root, args):
    """
    Distribute an exponent over every factor of a product:

    (ab)^p -> a^p * b^p
    (abc)^p -> a^p * b^p * c^p

    ``args`` is ``(factors, p)``, where ``factors`` is a non-empty sequence.
    The powered factors are multiplied together left to right.
    """
    factors, exponent = args

    head, tail = factors[0], factors[1:]
    product = head ** exponent

    for factor in tail:
        product *= factor ** exponent

    return product
# Explanation shown for duplicate_exponent possibilities.
# NOTE(review): the placeholders appear to be filled from (root, *args),
# which with args == (factors, p) makes {2} the exponent p -- confirm
# against the formatting code in ..possibilities.
MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
def match_remove_negative_exponent(node):
    """
    Detect a power whose exponent is negated:

    a^-p -> 1 / a^p

    :return: ``[P(node, remove_negative_exponent, (a, p))]`` when the
             exponent's ``negated`` flag is set, otherwise an empty list.
    """
    assert node.is_op(OP_POW)

    base, exponent = node

    if not exponent.negated:
        return []

    return [P(node, remove_negative_exponent, (base, exponent))]
def remove_negative_exponent(root, args):
    """
    Turn a negative exponent into the reciprocal of a positive power:

    a^-p -> 1 / a^p

    ``args`` is ``(a, p)``; the negation is stripped from ``p`` with
    ``reduce_negation()``.
    """
    base, exponent = args
    one = L(1)
    return one / base ** exponent.reduce_negation()


MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
def match_exponent_to_root(node):
    """
    Detect a power whose exponent is a fraction, so that it can be written
    as a root:

    a^(1 / m) -> sqrt(a, m)
    a^(n / m) -> sqrt(a^n, m)

    :return: ``[P(node, exponent_to_root, (a, n, m))]`` when the exponent is
             a division, otherwise an empty list.
    """
    assert node.is_op(OP_POW)

    base, exponent = node

    if not exponent.is_op(OP_DIV):
        return []

    root_args = (base,) + tuple(exponent)
    return [P(node, exponent_to_root, root_args)]
def exponent_to_root(root, args):
    """
    Write a power with a fractional exponent as a 'sqrt' node:

    a^(1 / m) -> sqrt(a, m)
    a^(n / m) -> sqrt(a^n, m)

    ``args`` is ``(a, n, m)``. When n equals 1, the radicand is ``a``
    itself rather than ``a^1``.
    """
    base, numerator, denominator = args

    if numerator == 1:
        radicand = base
    else:
        radicand = base ** numerator

    return N('sqrt', radicand, denominator)
def match_extend_exponent(node):
    """
    Detect a sum raised to a numeric exponent greater than one, so that one
    factor can be split off:

    (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1)  # n > 1

    :param node: An OP_POW expression node.
    :return: A single-item list with an ``extend_exponent`` possibility whose
             args are ``(left, right)``, or an empty list.
    """
    assert node.is_op(OP_POW)

    left, right = node

    # Enforce the documented n > 1 precondition. For n <= 1 there is no
    # factor to split off. A negated exponent is also rejected because
    # extend_exponent rebuilds the exponent from 'right.value' and does not
    # carry over its 'negated' flag.
    if not right.is_numeric() or right.negated or right.value <= 1:
        return []

    for n in Scope(node):
        if n.is_op(OP_ADD):
            return [P(node, extend_exponent, (left, right))]

    return []
def extend_exponent(root, args):
    """
    Split one factor off a power of a sum:

    (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1)  # n > 1

    If the remaining exponent n - 1 is exactly 1, the second factor is
    written as the plain sum, giving (a + ... + z)(a + ... + z).

    ``args`` is ``(left, right)``: the base and the numeric exponent leaf.
    """
    left, right = args
    remaining = right.value - 1

    # Branch on the remaining exponent itself rather than on 'n > 2'. That
    # test assumed an integer n and produced left * left (i.e. exponent 2)
    # for non-integer exponents such as 1.5.
    if remaining == 1:
        return left * left

    return left * left ** L(remaining)
def match_constant_exponent(node):
    """
    Detect a power whose exponent is the constant 0 or 1:

    a^0 -> 1
    a^1 -> a

    :param node: An OP_POW expression node.
    :return: A single-item list with a ``remove_power_of_zero`` or
             ``remove_power_of_one`` possibility (both with empty args), or
             an empty list.
    """
    assert node.is_op(OP_POW)

    exponent = node[1]

    if exponent == 0:
        return [P(node, remove_power_of_zero, ())]

    # NOTE(review): confirm that '==' on an expression leaf takes its
    # 'negated' flag into account. Otherwise a^-1 would be matched here as
    # a^1.
    if exponent == 1:
        return [P(node, remove_power_of_one, ())]

    return []
def remove_power_of_zero(root, args):
    """
    Replace a power with exponent zero by the constant one:

    a^0 -> 1

    Neither ``root`` nor ``args`` is needed to build the result.
    """
    one = L(1)
    return one


MESSAGES[remove_power_of_zero] = _('Power of zero {0} rewrites to 1.')
def remove_power_of_one(root, args):
    """
    Replace a power with exponent one by its root:

    a^1 -> a

    :return: The first child of the power node ``root``; ``args`` is unused.
    """
    base = root[0]
    return base
# Explanation shown for remove_power_of_one possibilities.
# NOTE(review): this rule is created with empty args, so {0} is presumably
# the matched power node itself -- confirm against ..possibilities.
MESSAGES[remove_power_of_one] = _('Remove the power of one in {0}.')
|