Commit d419cb36 authored by Taddeus Kroes's avatar Taddeus Kroes

Merge branch 'negated' of kompiler.org:trs

parents 8f398255 105cd269
graph_drawing @ 84ad376b
Subproject commit 11940973bdfef9432438b054c65b28af2eb97d0c Subproject commit 84ad376b81ac72e163bacd7b538df16cac9be153
...@@ -43,7 +43,7 @@ TYPE_MAP = { ...@@ -43,7 +43,7 @@ TYPE_MAP = {
OP_MAP = { OP_MAP = {
'+': OP_ADD, '+': OP_ADD,
# Either substraction or negation. Skip the operator sign in 'x' (= 2). # Either substraction or negation. Skip the operator sign in 'x' (= 2).
'-': lambda x: OP_SUB if len(x) > 2 else OP_NEG, '-': OP_SUB,
'*': OP_MUL, '*': OP_MUL,
'/': OP_DIV, '/': OP_DIV,
'^': OP_POW, '^': OP_POW,
...@@ -60,6 +60,10 @@ def to_expression(obj): ...@@ -60,6 +60,10 @@ def to_expression(obj):
class ExpressionBase(object): class ExpressionBase(object):
def __init__(self, *args, **kwargs):
self.negated = 0
def clone(self): def clone(self):
return copy.deepcopy(self) return copy.deepcopy(self)
...@@ -86,16 +90,11 @@ class ExpressionBase(object): ...@@ -86,16 +90,11 @@ class ExpressionBase(object):
if self.is_leaf: if self.is_leaf:
if other.is_leaf: if other.is_leaf:
# Both are leafs, string compare the value. # Both are leafs, string compare the value.
return str(self.value) < str(other.value) self_value = '-' * (self.negated & 1) + str(self.value)
# Self is a leaf, thus has less value than an expression node. other_value = '-' * (other.negated & 1) + str(other.value)
return True
return self_value < other_value
if self.is_op(OP_NEG) and self[0].is_leaf:
if other.is_leaf:
# Both are leafs, string compare the value.
return ('-' + str(self.value)) < str(other.value)
if other.is_op(OP_NEG) and other[0].is_leaf:
return ('-' + str(self.value)) < ('-' + str(other.value))
# Self is a leaf, thus has less value than an expression node. # Self is a leaf, thus has less value than an expression node.
return True return True
...@@ -113,24 +112,6 @@ class ExpressionBase(object): ...@@ -113,24 +112,6 @@ class ExpressionBase(object):
def is_op(self, op): def is_op(self, op):
return not self.is_leaf and self.op == op return not self.is_leaf and self.op == op
def is_op_or_negated(self, op):
if self.is_leaf:
return False
if self.op == OP_NEG:
return self[0].is_op(op)
return self.op == op
def is_leaf_or_negated(self):
if self.is_leaf:
return True
if self.is_op(OP_NEG):
return self[0].is_leaf
return False
def is_power(self): def is_power(self):
return not self.is_leaf and self.op == OP_POW return not self.is_leaf and self.op == OP_POW
...@@ -164,8 +145,13 @@ class ExpressionBase(object): ...@@ -164,8 +145,13 @@ class ExpressionBase(object):
def __pow__(self, other): def __pow__(self, other):
return ExpressionNode('^', self, to_expression(other)) return ExpressionNode('^', self, to_expression(other))
def __neg__(self): def reduce_negation(self, n=1):
return ExpressionNode('-', self) """Remove n negation flags from the node."""
return self.negate(-n)
def negate(self, n=1):
"""Negate the node n times."""
return negate(self, self.negated + n)
class ExpressionNode(Node, ExpressionBase): class ExpressionNode(Node, ExpressionBase):
...@@ -226,8 +212,10 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -226,8 +212,10 @@ class ExpressionNode(Node, ExpressionBase):
return (ExpressionLeaf(1), self[0], self[1]) return (ExpressionLeaf(1), self[0], self[1])
# rule: -r -> (1, r, 1) # rule: -r -> (1, r, 1)
if self.is_op(OP_NEG): # rule: --r -> (1, r, 1)
return (ExpressionLeaf(1), -self[0], ExpressionLeaf(1)) # rule: ---r -> (1, r, 1)
if self.negated:
return (ExpressionLeaf(1), self, ExpressionLeaf(1))
if self.op != OP_MUL: if self.op != OP_MUL:
return return
...@@ -309,7 +297,6 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -309,7 +297,6 @@ class ExpressionNode(Node, ExpressionBase):
class ExpressionLeaf(Leaf, ExpressionBase): class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs) super(ExpressionLeaf, self).__init__(*args, **kwargs)
self.type = TYPE_MAP[type(args[0])] self.type = TYPE_MAP[type(args[0])]
def __eq__(self, other): def __eq__(self, other):
...@@ -339,6 +326,11 @@ class ExpressionLeaf(Leaf, ExpressionBase): ...@@ -339,6 +326,11 @@ class ExpressionLeaf(Leaf, ExpressionBase):
# rule: 1 * r ^ 1 -> (1, r, 1) # rule: 1 * r ^ 1 -> (1, r, 1)
return (ExpressionLeaf(1), self, ExpressionLeaf(1)) return (ExpressionLeaf(1), self, ExpressionLeaf(1))
def actual_value(self):
assert self.is_numeric()
return (1 - 2 * (self.negated & 1)) * self.value
class Scope(object): class Scope(object):
...@@ -409,3 +401,11 @@ def get_scope(node): ...@@ -409,3 +401,11 @@ def get_scope(node):
scope.append(child) scope.append(child)
return scope return scope
def negate(node, n=1):
"""Negate the given node n times."""
node = node.clone()
node.negated = n
return node
...@@ -16,7 +16,7 @@ sys.path.insert(1, EXTERNAL_MODS) ...@@ -16,7 +16,7 @@ sys.path.insert(1, EXTERNAL_MODS)
from pybison import BisonParser, BisonSyntaxError from pybison import BisonParser, BisonSyntaxError
from graph_drawing.graph import generate_graph from graph_drawing.graph import generate_graph
from node import TYPE_OPERATOR, OP_COMMA from node import TYPE_OPERATOR, OP_COMMA, OP_NEG
from rules import RULES from rules import RULES
from possibilities import filter_duplicates, pick_suggestion, apply_suggestion from possibilities import filter_duplicates, pick_suggestion, apply_suggestion
...@@ -180,11 +180,13 @@ class Parser(BisonParser): ...@@ -180,11 +180,13 @@ class Parser(BisonParser):
return data return data
def hook_handler(self, target, option, names, values, retval): def hook_handler(self, target, option, names, values, retval):
if target in ['exp', 'line', 'input'] or not retval \ if target in ['exp', 'line', 'input'] or not retval:
or retval.type != TYPE_OPERATOR:
return retval return retval
if self.subtree_map: if not retval.negated and retval.type != TYPE_OPERATOR:
return retval
if self.subtree_map and retval.type == TYPE_OPERATOR:
# Update the subtree map to let the subtree point to its parent # Update the subtree map to let the subtree point to its parent
# node. # node.
parent_nodes = self.subtree_map.keys() parent_nodes = self.subtree_map.keys()
...@@ -193,10 +195,15 @@ class Parser(BisonParser): ...@@ -193,10 +195,15 @@ class Parser(BisonParser):
if child in parent_nodes: if child in parent_nodes:
self.subtree_map[child] = retval self.subtree_map[child] = retval
if retval.op not in RULES: if retval.type == TYPE_OPERATOR and retval.op in RULES:
return retval handlers = RULES[retval.op]
else:
handlers = []
if retval.negated:
handlers += RULES[OP_NEG]
for handler in RULES[retval.op]: for handler in handlers:
possibilities = handler(retval) possibilities = handler(retval)
# Record the subtree root node in order to avoid tree traversal. # Record the subtree root node in order to avoid tree traversal.
...@@ -343,7 +350,9 @@ class Parser(BisonParser): ...@@ -343,7 +350,9 @@ class Parser(BisonParser):
""" """
if option == 0: # rule: NEG exp if option == 0: # rule: NEG exp
return Node('-', values[1]) node = values[1]
node.negated += 1
return node
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
% (option, target)) # pragma: nocover % (option, target)) # pragma: nocover
...@@ -361,11 +370,8 @@ class Parser(BisonParser): ...@@ -361,11 +370,8 @@ class Parser(BisonParser):
return Node(values[1], values[0], values[2]) return Node(values[1], values[0], values[2])
if option == 4: # rule: exp MINUS exp if option == 4: # rule: exp MINUS exp
# It is necessary to call the hook_handler here explicitly, since node = values[2]
# the minus operator is internally represented as two nodes (unary node.negated += 1
# negation and binary plus).
node = Node('-', values[2])
node = self.hook_handler(target, option, names, values, node)
return Node('+', values[0], node) return Node('+', values[0], node)
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
......
from itertools import product, combinations from itertools import product, combinations
from ..node import Scope, OP_ADD, OP_MUL, OP_NEG from ..node import Scope, OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -18,7 +18,7 @@ def match_expand(node): ...@@ -18,7 +18,7 @@ def match_expand(node):
additions = [] additions = []
for n in Scope(node): for n in Scope(node):
if n.is_leaf or n.is_op(OP_NEG) and n[0].is_leaf: if n.is_leaf:
leaves.append(n) leaves.append(n)
elif n.op == OP_ADD: elif n.op == OP_ADD:
additions.append(n) additions.append(n)
......
from itertools import combinations from itertools import combinations
from .utils import least_common_multiple from .utils import least_common_multiple
from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL, OP_NEG from ..node import ExpressionLeaf as L, Scope, OP_DIV, OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -80,22 +80,11 @@ def match_add_constant_fractions(node): ...@@ -80,22 +80,11 @@ def match_add_constant_fractions(node):
p = [] p = []
def is_division(node): fractions = filter(lambda node: node.is_op(OP_DIV), Scope(node))
return node.is_op(OP_DIV) or \
(node.is_op(OP_NEG) and node[0].is_op(OP_DIV))
fractions = filter(is_division, Scope(node))
for a, b in combinations(fractions, 2): for a, b in combinations(fractions, 2):
if a.is_op(OP_NEG): na, da = a
na, da = a[0] nb, db = b
else:
na, da = a
if b.is_op(OP_NEG):
nb, db = b[0]
else:
nb, db = b
if da == db: if da == db:
# Equal denominators, add nominators to create a single fraction # Equal denominators, add nominators to create a single fraction
...@@ -116,20 +105,17 @@ def equalize_denominators(root, args): ...@@ -116,20 +105,17 @@ def equalize_denominators(root, args):
a / 2 + b / 4 -> 2a / 4 + b / 4 a / 2 + b / 4 -> 2a / 4 + b / 4
""" """
denom = args[2] denom = args[2]
scope = Scope(root) scope = Scope(root)
for fraction in args[:2]: for fraction in args[:2]:
n, d = fraction[0] if fraction.is_op(OP_NEG) else fraction n, d = fraction
mult = denom / d.value mult = denom / d.value
if mult != 1: if mult != 1:
n = L(n.value * mult) if n.is_numeric() else L(mult) * n n = L(n.value * mult) if n.is_numeric() else L(mult) * n
if fraction.is_op(OP_NEG): scope.remove(fraction, negate(n / L(d.value * mult),
scope.remove(fraction, -(n / L(d.value * mult))) fraction.negated))
else:
scope.remove(fraction, n / L(d.value * mult))
return scope.as_nary_node() return scope.as_nary_node()
...@@ -147,21 +133,11 @@ def add_nominators(root, args): ...@@ -147,21 +133,11 @@ def add_nominators(root, args):
""" """
# TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"? # TODO: is 'add' Appropriate when rewriting to "(a + (-c)) / b"?
ab, cb = args ab, cb = args
a, b = ab
if ab.is_op(OP_NEG):
a, b = ab[0]
else:
a, b = ab
if cb.is_op(OP_NEG):
c = -cb[0][0]
else:
c = cb[0]
scope = Scope(root) scope = Scope(root)
# Replace the left node with the new expression # Replace the left node with the new expression
scope.remove(ab, (a + c) / b) scope.remove(ab, (a + negate(cb[0], cb.negated)) / b)
# Remove the right node # Remove the right node
scope.remove(cb) scope.remove(cb)
......
from itertools import combinations from itertools import combinations
from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \ from ..node import ExpressionNode as Node, ExpressionLeaf as Leaf, Scope, \
OP_ADD, OP_MUL, OP_NEG OP_ADD, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -18,7 +18,6 @@ def match_combine_groups(node): ...@@ -18,7 +18,6 @@ def match_combine_groups(node):
ab + 2ab -> 3ab ab + 2ab -> 3ab
ab + ba -> 2ab ab + ba -> 2ab
""" """
# TODO: handle OP_NEG nodes
assert node.is_op(OP_ADD) assert node.is_op(OP_ADD)
p = [] p = []
...@@ -34,8 +33,7 @@ def match_combine_groups(node): ...@@ -34,8 +33,7 @@ def match_combine_groups(node):
l = len(scope) l = len(scope)
for i, sub_node in enumerate(scope): for i, sub_node in enumerate(scope):
if sub_node.is_numeric() or (sub_node.is_op(OP_NEG) if sub_node.is_numeric():
and sub_node[0].is_numeric()):
others = [scope[j] for j in range(i) + range(i + 1, l)] others = [scope[j] for j in range(i) + range(i + 1, l)]
if len(others) == 1: if len(others) == 1:
......
from ..node import get_scope, nary_node, OP_NEG, OP_ADD, OP_MUL, OP_DIV from ..node import get_scope, nary_node, OP_ADD, OP_MUL, OP_DIV
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -6,58 +6,48 @@ from ..translate import _ ...@@ -6,58 +6,48 @@ from ..translate import _
def match_negate_group(node): def match_negate_group(node):
""" """
--a -> a --a -> a
--ab -> ab -(a * ... * -b) -> ab
-(-ab + c) -> --ab - c
-(a + b + ... + z) -> -a + -b + ... + -z -(a + b + ... + z) -> -a + -b + ... + -z
""" """
assert node.is_op(OP_NEG) assert node.negated
val = node[0] if node.negated == 2:
if val.is_op(OP_NEG):
# --a # --a
return [P(node, double_negation, (node,))] return [P(node, double_negation, (node,))]
if not val.is_leaf: if not node.is_leaf:
scope = get_scope(val) scope = get_scope(node)
if not any(map(lambda n: n.is_op(OP_NEG), scope)):
return []
if val.is_op(OP_MUL): if node.is_op(OP_MUL) and any(map(lambda n: n.negated, scope)):
# --ab # -(-a)b
return [P(node, negate_polynome, (node, scope))] return [P(node, negate_group, (node, scope))]
elif val.is_op(OP_ADD): if node.is_op(OP_ADD):
# -(ab + c) -> -ab - c # -(ab + c) -> -ab - c
# -(-ab + c) -> ab - c # -(-ab + c) -> ab - c
return [P(node, negate_group, (node, scope))] return [P(node, negate_polynome, (node, scope))]
return [] return []
def negate_polynome(root, args): def negate_group(root, args):
""" """
# -a * -3c -> a * 3c -(a * -3c) -> a * 3c
--a * 3c -> a * 3c -(a * ... * -b) -> ab
--ab -> ab
--abc -> abc
""" """
node, scope = args node, scope = args
for i, n in enumerate(scope): for i, n in enumerate(scope):
# XXX: validate this property! if n.negated:
if n.is_op(OP_NEG): scope[i] = n.reduce_negation()
scope[i] = n[0]
return nary_node('*', scope)
raise RuntimeError('No negation node found in scope.') return nary_node('*', scope).reduce_negation()
MESSAGES[negate_polynome] = _('Apply negation to the polynome {1[0]}.') MESSAGES[negate_group] = _('Apply negation to the polynome {1[0]}.')
def negate_group(root, args): def negate_polynome(root, args):
""" """
-(-ab + ... + c) -> --ab + ... + -c -(-ab + ... + c) -> --ab + ... + -c
""" """
...@@ -70,16 +60,14 @@ def negate_group(root, args): ...@@ -70,16 +60,14 @@ def negate_group(root, args):
return nary_node('+', scope) return nary_node('+', scope)
MESSAGES[negate_group] = _('Apply negation to the subexpression {1[0]}.') MESSAGES[negate_polynome] = _('Apply negation to the subexpression {1[0]}.')
def double_negation(root, args): def double_negation(root, args):
""" """
--a -> a --a -> a
""" """
node = args[0] return negate(args[0], args[0].negated - 2)
return node[0][0]
MESSAGES[double_negation] = _('Remove double negation in {1}.') MESSAGES[double_negation] = _('Remove double negation in {1}.')
...@@ -92,14 +80,12 @@ def match_negated_division(node): ...@@ -92,14 +80,12 @@ def match_negated_division(node):
assert node.is_op(OP_DIV) assert node.is_op(OP_DIV)
a, b = node a, b = node
a_neg = a.is_op(OP_NEG)
b_neg = b.is_op(OP_NEG)
if a_neg and b_neg: if a.negated and b.negated:
return [P(node, double_negated_division, (node,))] return [P(node, double_negated_division, (node,))]
elif a_neg: elif a.negated:
return [P(node, single_negated_division, (a[0], b))] return [P(node, single_negated_division, (a[0], b))]
elif b_neg: elif b.negated:
return [P(node, single_negated_division, (a, b[0]))] return [P(node, single_negated_division, (a, b[0]))]
return [] return []
...@@ -132,3 +118,6 @@ def double_negated_division(root, args): ...@@ -132,3 +118,6 @@ def double_negated_division(root, args):
MESSAGES[double_negated_division] = \ MESSAGES[double_negated_division] = \
_('Eliminate top and bottom negation in {1}.') _('Eliminate top and bottom negation in {1}.')
# TODO: negated multiplication: -a * -b = ab
from itertools import combinations from itertools import combinations
from ..node import ExpressionLeaf as Leaf, Scope, OP_DIV, OP_MUL, OP_NEG from ..node import ExpressionLeaf as Leaf, Scope, negate, OP_DIV, OP_MUL
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -16,21 +16,10 @@ def add_numerics(root, args): ...@@ -16,21 +16,10 @@ def add_numerics(root, args):
-2 + -3 -> -5 -2 + -3 -> -5
""" """
n0, n1, c0, c1 = args n0, n1, c0, c1 = args
if c0.is_op(OP_NEG):
c0 = -c0[0].value
else:
c0 = c0.value
if c1.is_op(OP_NEG):
c1 = (-c1[0].value)
else:
c1 = c1.value
scope = Scope(root) scope = Scope(root)
# Replace the left node with the new expression # Replace the left node with the new expression
scope.remove(n0, Leaf(c0 + c1)) scope.remove(n0, Leaf(c0.actual_value() + c1.actual_value()))
# Remove the right node # Remove the right node
scope.remove(n1) scope.remove(n1)
...@@ -119,20 +108,12 @@ def match_multiply_zero(node): ...@@ -119,20 +108,12 @@ def match_multiply_zero(node):
assert node.is_op(OP_MUL) assert node.is_op(OP_MUL)
left, right = node left, right = node
is_zero = lambda n: n.is_leaf and n.value == 0
if is_zero(left):
negated = right.is_op(OP_NEG)
elif is_zero(right):
negated = left.is_op(OP_NEG)
elif left.is_op(OP_NEG) and is_zero(left[0]):
negated = not right.is_op(OP_NEG)
elif right.is_op(OP_NEG) and is_zero(right[0]):
negated = not left.is_op(OP_NEG)
else:
return []
return [P(node, multiply_zero, (negated,))] if (left.is_leaf and left.value == 0) \
or (right.is_leaf and right.value == 0):
return [P(node, multiply_zero, (left.negated + right.negated,))]
return []
def multiply_zero(root, args): def multiply_zero(root, args):
...@@ -143,12 +124,7 @@ def multiply_zero(root, args): ...@@ -143,12 +124,7 @@ def multiply_zero(root, args):
0 * -a -> -0 0 * -a -> -0
-0 * -a -> 0 -0 * -a -> 0
""" """
negated = args[0] return negate(Leaf(0), args[0])
if negated:
return -Leaf(0)
else:
return Leaf(0)
MESSAGES[multiply_zero] = _('Multiplication with zero yields zero.') MESSAGES[multiply_zero] = _('Multiplication with zero yields zero.')
...@@ -168,9 +144,7 @@ def match_multiply_numerics(node): ...@@ -168,9 +144,7 @@ def match_multiply_numerics(node):
for n in Scope(node): for n in Scope(node):
if n.is_numeric(): if n.is_numeric():
numerics.append((n, n.value)) numerics.append((n, n.actual_value()))
elif n.is_op(OP_NEG) and n[0].is_numeric():
numerics.append((n, -n[0].value))
for (n0, v0), (n1, v1) in combinations(numerics, 2): for (n0, v0), (n1, v1) in combinations(numerics, 2):
p.append(P(node, multiply_numerics, (n0, n1, v0, v1))) p.append(P(node, multiply_numerics, (n0, n1, v0, v1)))
......
from itertools import combinations from itertools import combinations
from ..node import Scope, OP_ADD, OP_NEG from ..node import Scope, OP_ADD
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from .numerics import add_numerics from .numerics import add_numerics
def is_numeric_or_negated_numeric(n):
return n.is_numeric() or (n.is_op(OP_NEG) and n[0].is_numeric())
def match_combine_polynomes(node, verbose=False): def match_combine_polynomes(node, verbose=False):
""" """
n + exp + m -> exp + (n + m) n + exp + m -> exp + (n + m)
...@@ -52,7 +48,7 @@ def match_combine_polynomes(node, verbose=False): ...@@ -52,7 +48,7 @@ def match_combine_polynomes(node, verbose=False):
# roots, or: same root and exponent -> combine coefficients. # roots, or: same root and exponent -> combine coefficients.
# TODO: Addition with zero, e.g. a + 0 -> a # TODO: Addition with zero, e.g. a + 0 -> a
if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \ if c0 == 1 and c1 == 1 and e0 == 1 and e1 == 1 \
and all(map(is_numeric_or_negated_numeric, [r0, r1])): and all(map(lambda n: n.is_numeric(), [r0, r1])):
# 2 + 3 -> 5 # 2 + 3 -> 5
# 2 + -3 -> -1 # 2 + -3 -> -1
# -2 + 3 -> 1 # -2 + 3 -> 1
......
from itertools import combinations from itertools import combinations
from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \ from ..node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
OP_NEG, OP_MUL, OP_DIV, OP_POW, OP_ADD OP_MUL, OP_DIV, OP_POW, OP_ADD
from ..possibilities import Possibility as P, MESSAGES from ..possibilities import Possibility as P, MESSAGES
from ..translate import _ from ..translate import _
...@@ -91,6 +91,18 @@ def match_subtract_exponents(node): ...@@ -91,6 +91,18 @@ def match_subtract_exponents(node):
return [] return []
def subtract_exponents(root, args):
"""
a^p / a^q -> a^(p - q)
"""
a, p, q = args
return a ** (p - q)
MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
def match_multiply_exponents(node): def match_multiply_exponents(node):
""" """
(a^p)^q -> a^(pq) (a^p)^q -> a^(pq)
...@@ -105,6 +117,18 @@ def match_multiply_exponents(node): ...@@ -105,6 +117,18 @@ def match_multiply_exponents(node):
return [] return []
def multiply_exponents(root, args):
"""
(a^p)^q -> a^(pq)
"""
a, p, q = args
return a ** (p * q)
MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
def match_duplicate_exponent(node): def match_duplicate_exponent(node):
""" """
(ab)^p -> a^p * b^p (ab)^p -> a^p * b^p
...@@ -119,20 +143,49 @@ def match_duplicate_exponent(node): ...@@ -119,20 +143,49 @@ def match_duplicate_exponent(node):
return [] return []
def duplicate_exponent(root, args):
"""
(ab)^p -> a^p * b^p
(abc)^p -> a^p * b^p * c^p
"""
ab, p = args
result = ab[0] ** p
for b in ab[1:]:
result *= b ** p
return result
MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
def match_remove_negative_exponent(node): def match_remove_negative_exponent(node):
""" """
a^-p -> 1 / a^p a^-p -> 1 / a^p
""" """
assert node.is_op(OP_POW) assert node.is_op(OP_POW)
left, right = node a, p = node
if right.is_op(OP_NEG): if p.negated:
return [P(node, remove_negative_exponent, (left, right[0]))] return [P(node, remove_negative_exponent, (a, p))]
return [] return []
def remove_negative_exponent(root, args):
"""
a^-p -> 1 / a^p
"""
a, p = args
return L(1) / a ** p.reduce_negation()
MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
def match_exponent_to_root(node): def match_exponent_to_root(node):
""" """
a^(1 / m) -> sqrt(a, m) a^(1 / m) -> sqrt(a, m)
...@@ -148,6 +201,16 @@ def match_exponent_to_root(node): ...@@ -148,6 +201,16 @@ def match_exponent_to_root(node):
return [] return []
def exponent_to_root(root, args):
"""
a^(1 / m) -> sqrt(a, m)
a^(n / m) -> sqrt(a^n, m)
"""
a, n, m = args
return N('sqrt', a if n == 1 else a ** n, m)
def match_extend_exponent(node): def match_extend_exponent(node):
""" """
(a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1 (a + ... + z)^n -> (a + ... + z)(a + ... + z)^(n - 1) # n > 1
...@@ -174,66 +237,3 @@ def extend_exponent(root, args): ...@@ -174,66 +237,3 @@ def extend_exponent(root, args):
return left * left ** L(right.value - 1) return left * left ** L(right.value - 1)
return left * left return left * left
def subtract_exponents(root, args):
"""
a^p / a^q -> a^(p - q)
"""
a, p, q = args
return a ** (p - q)
MESSAGES[subtract_exponents] = _('Substract the exponents {2} and {3}.')
def multiply_exponents(root, args):
"""
(a^p)^q -> a^(pq)
"""
a, p, q = args
return a ** (p * q)
MESSAGES[multiply_exponents] = _('Multiply the exponents {2} and {3}.')
def duplicate_exponent(root, args):
"""
(ab)^p -> a^p * b^p
(abc)^p -> a^p * b^p * c^p
"""
ab, p = args
result = ab[0] ** p
for b in ab[1:]:
result *= b ** p
return result
MESSAGES[duplicate_exponent] = _('Duplicate the exponent {2}.')
def remove_negative_exponent(root, args):
"""
a^-p -> 1 / a^p
"""
a, p = args
return L(1) / a ** p
MESSAGES[remove_negative_exponent] = _('Remove negative exponent {2}.')
def exponent_to_root(root, args):
"""
a^(1 / m) -> sqrt(a, m)
a^(n / m) -> sqrt(a^n, m)
"""
a, n, m = args
return N('sqrt', a if n == 1 else a ** n, m)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment