Commit 05c35766 authored by Taddeus Kroes's avatar Taddeus Kroes

Added rules that remove powers of 0 and 1.

parent 32519fa2
...@@ -4,7 +4,7 @@ from .factors import match_expand ...@@ -4,7 +4,7 @@ from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \ from .powers import match_add_exponents, match_subtract_exponents, \
match_multiply_exponents, match_duplicate_exponent, \ match_multiply_exponents, match_duplicate_exponent, \
match_remove_negative_exponent, match_exponent_to_root, \ match_remove_negative_exponent, match_exponent_to_root, \
match_extend_exponent match_extend_exponent, match_constant_exponent
from .numerics import match_add_numerics, match_divide_numerics, \ from .numerics import match_add_numerics, match_divide_numerics, \
match_multiply_numerics, match_multiply_zero match_multiply_numerics, match_multiply_zero
from .fractions import match_constant_division, match_add_constant_fractions, \ from .fractions import match_constant_division, match_add_constant_fractions, \
...@@ -22,6 +22,6 @@ RULES = { ...@@ -22,6 +22,6 @@ RULES = {
match_constant_division, match_negated_division], match_constant_division, match_negated_division],
OP_POW: [match_multiply_exponents, match_duplicate_exponent, OP_POW: [match_multiply_exponents, match_duplicate_exponent,
match_remove_negative_exponent, match_exponent_to_root, match_remove_negative_exponent, match_exponent_to_root,
match_extend_exponent], match_extend_exponent, match_constant_exponent],
OP_NEG: [match_negate_polynome], OP_NEG: [match_negate_polynome],
} }
...@@ -67,4 +67,4 @@ def combine_groups(root, args): ...@@ -67,4 +67,4 @@ def combine_groups(root, args):
MESSAGES[combine_groups] = \ MESSAGES[combine_groups] = \
_('Group "{2}" is multiplied by {1} and {4}, combine them.') _('Group "{3}" is multiplied by {2} and {5}, combine them.')
...@@ -17,8 +17,9 @@ def match_add_exponents(node): ...@@ -17,8 +17,9 @@ def match_add_exponents(node):
p = [] p = []
powers = {} powers = {}
scope = Scope(node)
for n in Scope(node): for n in scope:
if n.is_identifier(): if n.is_identifier():
s = n s = n
exponent = L(1) exponent = L(1)
...@@ -41,7 +42,7 @@ def match_add_exponents(node): ...@@ -41,7 +42,7 @@ def match_add_exponents(node):
# create a single power with that root # create a single power with that root
if len(occurrences) > 1: if len(occurrences) > 1:
for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2): for (n0, e1, a0), (n1, e2, a1) in combinations(occurrences, 2):
p.append(P(node, add_exponents, (n0, n1, a0, e1, e2))) p.append(P(node, add_exponents, (scope, n0, n1, a0, e1, e2)))
return p return p
...@@ -50,8 +51,7 @@ def add_exponents(root, args): ...@@ -50,8 +51,7 @@ def add_exponents(root, args):
""" """
a^p * a^q -> a^(p + q) a^p * a^q -> a^(p + q)
""" """
n0, n1, a, p, q = args scope, n0, n1, a, p, q = args
scope = Scope(root)
# Replace the left node with the new expression # Replace the left node with the new expression
scope.replace(n0, a ** (p + q)) scope.replace(n0, a ** (p + q))
...@@ -62,7 +62,7 @@ def add_exponents(root, args): ...@@ -62,7 +62,7 @@ def add_exponents(root, args):
return scope.as_nary_node() return scope.as_nary_node()
MESSAGES[add_exponents] = _('Add the exponents of {1} and {2}.') MESSAGES[add_exponents] = _('Add the exponents of {2} and {3}.')
def match_subtract_exponents(node): def match_subtract_exponents(node):
...@@ -237,3 +237,40 @@ def extend_exponent(root, args): ...@@ -237,3 +237,40 @@ def extend_exponent(root, args):
return left * left ** L(right.value - 1) return left * left ** L(right.value - 1)
return left * left return left * left
def match_constant_exponent(node):
"""
(a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
"""
assert node.is_op(OP_POW)
exponent = node[1]
if exponent == 0:
return [P(node, remove_power_of_zero, ())]
if exponent == 1:
return [P(node, remove_power_of_one, ())]
return []
def remove_power_of_zero(root, args):
"""
a ^ 0 -> 1
"""
return L(1)
MESSAGES[remove_power_of_zero] = _('Power of zero {0} rewrites to 1.')
def remove_power_of_one(root, args):
"""
a ^ 1 -> a
"""
return root[0]
MESSAGES[remove_power_of_one] = _('Remove the power of one in {0}.')
...@@ -3,9 +3,10 @@ from src.rules.powers import match_add_exponents, add_exponents, \ ...@@ -3,9 +3,10 @@ from src.rules.powers import match_add_exponents, add_exponents, \
match_multiply_exponents, multiply_exponents, \ match_multiply_exponents, multiply_exponents, \
match_duplicate_exponent, duplicate_exponent, \ match_duplicate_exponent, duplicate_exponent, \
match_remove_negative_exponent, remove_negative_exponent, \ match_remove_negative_exponent, remove_negative_exponent, \
match_exponent_to_root, exponent_to_root match_exponent_to_root, exponent_to_root, \
match_constant_exponent, remove_power_of_zero, remove_power_of_one
from src.node import Scope, ExpressionNode as N
from src.possibilities import Possibility as P from src.possibilities import Possibility as P
from src.node import ExpressionNode as N
from tests.rulestestcase import RulesTestCase, tree from tests.rulestestcase import RulesTestCase, tree
...@@ -17,7 +18,7 @@ class TestRulesPowers(RulesTestCase): ...@@ -17,7 +18,7 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root) possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, add_exponents, (n0, n1, a, p, q))]) [P(root, add_exponents, (Scope(root), n0, n1, a, p, q))])
def test_match_add_exponents_ternary(self): def test_match_add_exponents_ternary(self):
a, p, q, r = tree('a,p,q,r') a, p, q, r = tree('a,p,q,r')
...@@ -25,9 +26,9 @@ class TestRulesPowers(RulesTestCase): ...@@ -25,9 +26,9 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root) possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, add_exponents, (n0, n1, a, p, q)), [P(root, add_exponents, (Scope(root), n0, n1, a, p, q)),
P(root, add_exponents, (n0, n2, a, p, r)), P(root, add_exponents, (Scope(root), n0, n2, a, p, r)),
P(root, add_exponents, (n1, n2, a, q, r))]) P(root, add_exponents, (Scope(root), n1, n2, a, q, r))])
def test_match_add_exponents_multiple_identifiers(self): def test_match_add_exponents_multiple_identifiers(self):
a, b, p, q = tree('a,b,p,q') a, b, p, q = tree('a,b,p,q')
...@@ -35,8 +36,8 @@ class TestRulesPowers(RulesTestCase): ...@@ -35,8 +36,8 @@ class TestRulesPowers(RulesTestCase):
possibilities = match_add_exponents(root) possibilities = match_add_exponents(root)
self.assertEqualPos(possibilities, self.assertEqualPos(possibilities,
[P(root, add_exponents, (a0, a1, a, p, q)), [P(root, add_exponents, (Scope(root), a0, a1, a, p, q)),
P(root, add_exponents, (b0, b1, b, p, q))]) P(root, add_exponents, (Scope(root), b0, b1, b, p, q))])
def test_match_subtract_exponents_powers(self): def test_match_subtract_exponents_powers(self):
a, p, q = tree('a,p,q') a, p, q = tree('a,p,q')
...@@ -103,8 +104,8 @@ class TestRulesPowers(RulesTestCase): ...@@ -103,8 +104,8 @@ class TestRulesPowers(RulesTestCase):
a, p, q = tree('a,p,q') a, p, q = tree('a,p,q')
n0, n1 = root = a ** p * a ** q n0, n1 = root = a ** p * a ** q
self.assertEqualNodes(add_exponents(root, (n0, n1, a, p, q)), self.assertEqualNodes(add_exponents(root,
a ** (p + q)) (Scope(root), n0, n1, a, p, q)), a ** (p + q))
def test_subtract_exponents(self): def test_subtract_exponents(self):
a, p, q = tree('a,p,q') a, p, q = tree('a,p,q')
...@@ -147,3 +148,21 @@ class TestRulesPowers(RulesTestCase): ...@@ -147,3 +148,21 @@ class TestRulesPowers(RulesTestCase):
self.assertEqualNodes(exponent_to_root(root, (a, l1, m)), self.assertEqualNodes(exponent_to_root(root, (a, l1, m)),
N('sqrt', a, m)) N('sqrt', a, m))
def test_match_constant_exponent(self):
a0, a1, a2 = tree('a0,a1,a2')
self.assertEqualPos(match_constant_exponent(a0),
[P(a0, remove_power_of_zero, ())])
self.assertEqualPos(match_constant_exponent(a1),
[P(a1, remove_power_of_one, ())])
self.assertEqualPos(match_constant_exponent(a2), [])
def test_remove_power_of_zero(self):
self.assertEqual(remove_power_of_zero(tree('a0'), ()), 1)
def test_remove_power_of_one(self):
a1 = tree('a1')
self.assertEqual(remove_power_of_one(a1, ()), a1[0])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment