Explorar o código

Added a bunch of square root rewrite rules.

Taddeus Kroes %!s(int64=14) %!d(string=hai) anos
pai
achega
746d7957cb
Modificáronse 4 ficheiros con 232 adicións e 2 borrados
  1. 3 1
      src/rules/__init__.py
  2. 9 1
      src/rules/precedences.py
  3. 120 0
      src/rules/sqrt.py
  4. 100 0
      tests/test_rules_sqrt.py

+ 3 - 1
src/rules/__init__.py

@@ -1,5 +1,5 @@
 from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_NEG, OP_SIN, OP_COS, \
-        OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS
+        OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS, OP_SQRT
 from .groups import match_combine_groups
 from .factors import match_expand
 from .powers import match_add_exponents, match_subtract_exponents, \
@@ -29,6 +29,7 @@ from .integrals import match_solve_indef, match_constant_integral, \
         match_sum_rule_integral, match_remove_indef_constant
 from .lineq import match_move_term
 from .absolute import match_factor_out_abs_term
+from .sqrt import match_reduce_sqrt
 
 
 RULES = {
@@ -65,4 +66,5 @@ RULES = {
         OP_INT_INDEF: [match_remove_indef_constant, match_solve_indef],
         OP_EQ: [match_move_term],
         OP_ABS: [match_factor_out_abs_term],
+        OP_SQRT: [match_reduce_sqrt],
         }

+ 9 - 1
src/rules/precedences.py

@@ -1,7 +1,7 @@
 from .factors import expand_double, expand_single
 from .sort import move_constant
 from .numerics import multiply_one, multiply_zero, reduce_fraction_constants, \
-        raise_numerics, remove_zero
+        raise_numerics, remove_zero, multiply_numerics
 from .logarithmic import factor_in_exponent_multiplicant, \
         factor_out_exponent, raised_base, factor_out_exponent_important
 from .derivatives import chain_rule
@@ -10,6 +10,7 @@ from .negation import double_negation, negated_factor, negated_nominator, \
 from .fractions import multiply_with_fraction
 from .integrals import factor_out_constant, integrate_variable_root
 from .powers import remove_power_of_one
+from .sqrt import quadrant_sqrt, extract_sqrt_mult_priority
 
 
 # Functions to move to the beginning of the possibilities list. Pairs of within
@@ -48,6 +49,13 @@ RELATIVE = [
         # int x dx  ->  int x ^ 1 dx  # do not remove power of one that has
         #                             # deliberately been inserted
         (integrate_variable_root, remove_power_of_one),
+
+        # When simplifying square roots, bring numeric quadrants out of the
+        # root first
+        (extract_sqrt_mult_priority, multiply_numerics),
+
+        # sqrt(2 ^ 2)  ->  2  # not sqrt 4
+        (quadrant_sqrt, raise_numerics),
         ]
 
 

+ 120 - 0
src/rules/sqrt.py

@@ -0,0 +1,120 @@
+import math
+
+from .utils import greatest_common_divisor, dividers, is_prime
+from ..node import ExpressionLeaf as Leaf, Scope, OP_SQRT, OP_MUL, sqrt
+from ..possibilities import Possibility as P, MESSAGES
+from ..translate import _
+
+
+def is_eliminateable_sqrt(n):
+    if isinstance(n, int):
+        return n > 3 and int(math.sqrt(n)) ** 2 == n
+
+    if n.negated:
+        return False
+
+    if n.is_numeric():
+        return is_eliminateable_sqrt(n.value)
+
+    return n.is_power(2)
+
+
+def match_reduce_sqrt(node):
+    """
+    sqrt(a ^ 2)  ->  a
+    sqrt(a) and eval(sqrt(a)) in Z              ->  eval(sqrt(a))
+    sqrt(a) and a == b ^ 2 * c with a,b,c in Z  ->  sqrt(eval(b ^ 2) * c)
+    sqrt(ab)     ->  sqrt(a)sqrt(b)
+    """
+    assert node.is_op(OP_SQRT)
+
+    exp = node[0]
+
+    if exp.negated:
+        return []
+
+    if exp.is_power(2):
+        return [P(node, quadrant_sqrt)]
+
+    if exp.is_numeric():
+        reduced = int(math.sqrt(exp.value))
+
+        if reduced ** 2 == exp.value:
+            return [P(node, constant_sqrt, (reduced,))]
+
+        div = filter(is_eliminateable_sqrt, dividers(exp.value))
+        div.sort(lambda a, b: cmp(is_prime(b), is_prime(a)))
+
+        return [P(node, split_dividers, (m, exp.value / m)) for m in div]
+
+    if exp.is_op(OP_MUL):
+        scope = Scope(exp)
+        p = []
+
+        for n in scope:
+            if is_eliminateable_sqrt(n):
+                p.append(P(node, extract_sqrt_mult_priority, (scope, n)))
+            else:
+                p.append(P(node, extract_sqrt_multiplicant, (scope, n)))
+
+        return p
+
+    return []
+
+
+def quadrant_sqrt(root, args):
+    """
+    sqrt(a ^ 2)  ->  a
+    """
+    return root[0][0].negate(root.negated)
+
+
+MESSAGES[quadrant_sqrt] = \
+        _('The square root of a quadrant reduces to the raised root.')
+
+
+def constant_sqrt(root, args):
+    """
+    sqrt(a) and eval(sqrt(a)) in Z  ->  eval(sqrt(a))
+    """
+    return Leaf(args[0]).negate(root.negated)
+
+
+MESSAGES[constant_sqrt] = \
+        _('The square root of {0[0]} is {1}.')
+
+
+def split_dividers(root, args):
+    """
+    sqrt(a) and b * c = a with a,b,c in Z  ->  sqrt(a * b)
+    """
+    b, c = args
+
+    return sqrt(Leaf(b) * c)
+
+
+MESSAGES[split_dividers] = _('Write {0[0]} as {1} * {2} to so that {1} can ' \
+        'be brought outside of the square root.')
+
+
+def extract_sqrt_multiplicant(root, args):
+    """
+    sqrt(ab)     ->  sqrt(a)sqrt(b)
+    """
+    scope, a = args
+    scope.remove(a)
+
+    return (sqrt(a) * sqrt(scope.as_nary_node())).negate(root.negated)
+
+
+MESSAGES[extract_sqrt_multiplicant] = _('Extract {2} from {0}.')
+
+
+def extract_sqrt_mult_priority(root, args):
+    """
+    sqrt(ab) and sqrt(a) in Z  ->  sqrt(a)sqrt(b)
+    """
+    return extract_sqrt_multiplicant(root, args)
+
+
+MESSAGES[extract_sqrt_mult_priority] = MESSAGES[extract_sqrt_multiplicant]

+ 100 - 0
tests/test_rules_sqrt.py

@@ -0,0 +1,100 @@
+from src.rules.sqrt import is_eliminateable_sqrt, match_reduce_sqrt, \
+        quadrant_sqrt, constant_sqrt, split_dividers, \
+        extract_sqrt_multiplicant, extract_sqrt_mult_priority
+from src.node import Scope, sqrt
+from src.possibilities import Possibility as P
+from tests.rulestestcase import RulesTestCase, tree
+
+
+class TestRulesSqrt(RulesTestCase):
+
+    def test_is_eliminateable_sqrt(self):
+        self.assertFalse(is_eliminateable_sqrt(3))
+        self.assertTrue(is_eliminateable_sqrt(4))
+        self.assertTrue(is_eliminateable_sqrt(9))
+        self.assertTrue(is_eliminateable_sqrt(tree('9')))
+        self.assertFalse(is_eliminateable_sqrt(tree('-9')))
+        self.assertFalse(is_eliminateable_sqrt(tree('5')))
+        self.assertTrue(is_eliminateable_sqrt(tree('a ^ 2')))
+        self.assertFalse(is_eliminateable_sqrt(tree('a ^ 3')))
+        self.assertFalse(is_eliminateable_sqrt(tree('a')))
+
+    def test_match_reduce_sqrt_none(self):
+        root = tree('sqrt(a)')
+        self.assertEqualPos(match_reduce_sqrt(root), [])
+
+        root = tree('sqrt(-4)')
+        self.assertEqualPos(match_reduce_sqrt(root), [])
+
+    def test_match_reduce_sqrt_quadrant(self):
+        root = tree('sqrt(a ^ 2)')
+        self.assertEqualPos(match_reduce_sqrt(root), [P(root, quadrant_sqrt)])
+
+    def test_match_reduce_sqrt_constant(self):
+        root = tree('sqrt(4)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, constant_sqrt, (2,))])
+
+    def test_match_reduce_sqrt_dividers(self):
+        root = tree('sqrt(8)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, split_dividers, (4, 2))])
+
+        root = tree('sqrt(27)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, split_dividers, (9, 3))])
+
+    def test_match_reduce_sqrt_mult_priority(self):
+        root = tree('sqrt(9 * 3)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, extract_sqrt_mult_priority, (Scope(root[0]), 9)),
+                 P(root, extract_sqrt_multiplicant, (Scope(root[0]), 3))])
+
+    def test_match_reduce_sqrt_mult(self):
+        ((l2, x),) = root = tree('sqrt(2x)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, extract_sqrt_multiplicant, (Scope(root[0]), l2)),
+                 P(root, extract_sqrt_multiplicant, (Scope(root[0]), x))])
+
+        (((l2, x), y),) = root = tree('sqrt(2xy)')
+        self.assertEqualPos(match_reduce_sqrt(root),
+                [P(root, extract_sqrt_multiplicant, (Scope(root[0]), l2)),
+                 P(root, extract_sqrt_multiplicant, (Scope(root[0]), x)),
+                 P(root, extract_sqrt_multiplicant, (Scope(root[0]), y))])
+
+    def test_quadrant_sqrt(self):
+        root, expect = tree('sqrt(a ^ 2), a')
+        self.assertEqual(quadrant_sqrt(root, ()), expect)
+
+        root, expect = tree('-sqrt(a ^ 2), -a')
+        self.assertEqual(quadrant_sqrt(root, ()), expect)
+
+    def test_constant_sqrt(self):
+        root = tree('sqrt(4)')
+        self.assertEqual(constant_sqrt(root, (2,)), 2)
+
+    def test_split_dividers(self):
+        root, expect = tree('sqrt(27), sqrt(9 * 3)')
+        self.assertEqual(split_dividers(root, (9, 3)), expect)
+
+    def test_extract_sqrt_multiplicant(self):
+        root, expect = tree('sqrt(2x), sqrt(2)sqrt(x)')
+        l2, x = mul = root[0]
+        self.assertEqual(extract_sqrt_multiplicant(root, (Scope(mul), l2,)),
+                         expect)
+
+        root, expect = tree('-sqrt(2x), -sqrt(2)sqrt(x)')
+        l2, x = mul = root[0]
+        self.assertEqual(extract_sqrt_multiplicant(root, (Scope(mul), l2,)),
+                         expect)
+
+        root, expect = tree('sqrt(2xy), sqrt(x)sqrt(2y)')
+        (l2, x), y = mul = root[0]
+        self.assertEqual(extract_sqrt_multiplicant(root, (Scope(mul), x,)),
+                         expect)
+
+    def test_extract_sqrt_mult_priority(self):
+        root, expect = tree('sqrt(9 * 3), sqrt(9)sqrt(3)')
+        l9, l3 = mul = root[0]
+        self.assertEqual(extract_sqrt_mult_priority(root, (Scope(mul), l9,)),
+                         expect)