Commit 72f94a0f authored by Taddeüs Kroes's avatar Taddeüs Kroes

Significantly improved the parser.

- Thought over and reconfigured precedences. These are now also synchronized
  with precedences in the line printer.
- The subscript operator ("_") has been added to the grammar similarly to
  exponentiation ("^"). This resolved some issues with the use of operators in
  integral bounds.
- Unary negation now has a higher precedence, but the parser moves negations up
  in the tree as far as possible (within the bounds of correctness and without
  causing cycles during reduction).
- All unit tests have been updated using the new syntax.
- graph_drawing has been updated to include the new syntax of the line printer.
parent 9cbf75f1
graph_drawing @ e17779f2
Subproject commit 107e93a1697600860f4e02175f47cf8650affd0f Subproject commit e17779f211ca67336c3a6ac9038f552f91ff2f79
...@@ -21,7 +21,7 @@ import re ...@@ -21,7 +21,7 @@ import re
sys.path.insert(0, os.path.realpath('external')) sys.path.insert(0, os.path.realpath('external'))
from graph_drawing.graph import generate_graph from graph_drawing.graph import generate_graph
from graph_drawing.line import generate_line from graph_drawing.line import generate_line, preprocess_node
from graph_drawing.node import Node, Leaf from graph_drawing.node import Node, Leaf
...@@ -70,6 +70,16 @@ OP_REWRITE_ALL = 24 ...@@ -70,6 +70,16 @@ OP_REWRITE_ALL = 24
OP_REWRITE_ALL_VERBOSE = 25 OP_REWRITE_ALL_VERBOSE = 25
OP_REWRITE = 26 OP_REWRITE = 26
# Different types of derivative
OP_PRIME = 27
OP_DXDER = 28
OP_PARENS = 29
OP_BRACKETS = 30
OP_CBRACKETS = 31
UNARY_FUNCTIONS = [OP_INT, OP_DXDER, OP_LOG]
# Special identifiers # Special identifiers
PI = 'pi' PI = 'pi'
E = 'e' E = 'e'
...@@ -102,7 +112,7 @@ OP_MAP = { ...@@ -102,7 +112,7 @@ OP_MAP = {
'tan': OP_TAN, 'tan': OP_TAN,
'sqrt': OP_SQRT, 'sqrt': OP_SQRT,
'int': OP_INT, 'int': OP_INT,
'der': OP_DER, '\'': OP_PRIME,
'solve': OP_SOLVE, 'solve': OP_SOLVE,
'log': OP_LOG, 'log': OP_LOG,
'=': OP_EQ, '=': OP_EQ,
...@@ -114,9 +124,13 @@ OP_MAP = { ...@@ -114,9 +124,13 @@ OP_MAP = {
} }
OP_VALUE_MAP = dict([(v, k) for k, v in OP_MAP.iteritems()]) OP_VALUE_MAP = dict([(v, k) for k, v in OP_MAP.iteritems()])
OP_MAP['ln'] = OP_LOG
OP_VALUE_MAP[OP_INT_INDEF] = 'indef' OP_VALUE_MAP[OP_INT_INDEF] = 'indef'
OP_VALUE_MAP[OP_ABS] = 'abs' OP_VALUE_MAP[OP_ABS] = '||'
OP_VALUE_MAP[OP_DXDER] = 'd/d'
OP_VALUE_MAP[OP_PARENS] = '()'
OP_VALUE_MAP[OP_BRACKETS] = '[]'
OP_VALUE_MAP[OP_CBRACKETS] = '{}'
OP_MAP['ln'] = OP_LOG
TOKEN_MAP = { TOKEN_MAP = {
OP_COMMA: 'COMMA', OP_COMMA: 'COMMA',
...@@ -133,9 +147,10 @@ TOKEN_MAP = { ...@@ -133,9 +147,10 @@ TOKEN_MAP = {
OP_COS: 'FUNCTION', OP_COS: 'FUNCTION',
OP_TAN: 'FUNCTION', OP_TAN: 'FUNCTION',
OP_INT: 'INTEGRAL', OP_INT: 'INTEGRAL',
OP_DER: 'FUNCTION', OP_DXDER: 'DERIVATIVE',
OP_PRIME: 'PRIME',
OP_SOLVE: 'FUNCTION', OP_SOLVE: 'FUNCTION',
OP_LOG: 'FUNCTION', OP_LOG: 'LOGARITHM',
OP_EQ: 'EQ', OP_EQ: 'EQ',
OP_POSSIBILITIES: 'POSSIBILITIES', OP_POSSIBILITIES: 'POSSIBILITIES',
OP_HINT: 'HINT', OP_HINT: 'HINT',
...@@ -152,6 +167,11 @@ def to_expression(obj): ...@@ -152,6 +167,11 @@ def to_expression(obj):
return ExpressionLeaf(obj) return ExpressionLeaf(obj)
def bounds_str(f, a, b):
left = str(ExpressionNode(OP_SUBSCRIPT, f, a, no_spacing=True))
return left + str(ExpressionNode(OP_POW, Leaf(1), b, no_spacing=True))[1:]
class ExpressionBase(object): class ExpressionBase(object):
def __lt__(self, other): def __lt__(self, other):
""" """
...@@ -208,7 +228,8 @@ class ExpressionBase(object): ...@@ -208,7 +228,8 @@ class ExpressionBase(object):
return copy.deepcopy(self) return copy.deepcopy(self)
def is_op(self, *ops): def is_op(self, *ops):
return not self.is_leaf and self.op in ops return not self.is_leaf and (self.op in ops or
(self.op in (OP_DXDER, OP_PRIME) and OP_DER in ops))
def is_power(self, exponent=None): def is_power(self, exponent=None):
if self.is_leaf or self.op != OP_POW: if self.is_leaf or self.op != OP_POW:
...@@ -266,9 +287,9 @@ class ExpressionBase(object): ...@@ -266,9 +287,9 @@ class ExpressionBase(object):
return self.negate(-n) return self.negate(-n)
def negate(self, n=1): def negate(self, n=1, clone=True):
"""Negate the node n times.""" """Negate the node n times."""
return negate(self, self.negated + n, clone=True) return negate(self, self.negated + n, clone=clone)
def contains(self, node, include_self=True): def contains(self, node, include_self=True):
""" """
...@@ -290,6 +311,7 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -290,6 +311,7 @@ class ExpressionNode(Node, ExpressionBase):
super(ExpressionNode, self).__init__(*args, **kwargs) super(ExpressionNode, self).__init__(*args, **kwargs)
self.type = TYPE_OPERATOR self.type = TYPE_OPERATOR
op = args[0] op = args[0]
self.parens = False
if isinstance(op, str): if isinstance(op, str):
self.value = op self.value = op
...@@ -298,38 +320,6 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -298,38 +320,6 @@ class ExpressionNode(Node, ExpressionBase):
self.value = OP_VALUE_MAP[op] self.value = OP_VALUE_MAP[op]
self.op = op self.op = op
def construct_derivative(self, children):
f = children[0]
if len(children) < 2:
# der(der(x ^ 2)) -> [x ^ 2]''
if self[0].is_op(OP_DER) and len(self[0]) < 2:
return f + '\''
# der(x ^ 2) -> [x ^ 2]'
return '[' + f + ']\''
# der(x ^ 2, x) -> d/dx (x ^ 2)
return 'd/d%s (%s)' % (children[1], f)
def construct_logarithm(self, children):
if self[0].is_op(OP_ABS):
content = children[0]
else:
content = '(' + children[0] + ')'
# log(a, e) -> ln(a)
if self[1].is_identifier(E):
return 'ln%s' % content
# log(a, 10) -> log(a)
if self[1] == 10:
return 'log%s' % content
# log(a, 2) -> log_2(a)
if children[1].isdigit():
return 'log_%s%s' % (children[1], content)
def construct_integral(self, children): def construct_integral(self, children):
# Make sure that any needed parentheses around f(x) are generated, # Make sure that any needed parentheses around f(x) are generated,
# and append ' dx' to it (result 'f(x) dx') # and append ' dx' to it (result 'f(x) dx')
...@@ -375,8 +365,8 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -375,8 +365,8 @@ class ExpressionNode(Node, ExpressionBase):
return '|%s|' % children[0] return '|%s|' % children[0]
constructors = { constructors = {
OP_DER: self.construct_derivative, #OP_DER: self.construct_derivative,
OP_LOG: self.construct_logarithm, #OP_LOG: self.construct_logarithm,
OP_INT: self.construct_integral, OP_INT: self.construct_integral,
OP_INT_INDEF: self.construct_indef_integral OP_INT_INDEF: self.construct_indef_integral
} }
...@@ -393,9 +383,56 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -393,9 +383,56 @@ class ExpressionNode(Node, ExpressionBase):
and len(self) == 1 and self[0].is_op(OP_ABS): and len(self) == 1 and self[0].is_op(OP_ABS):
return self.title() + children[0] return self.title() + children[0]
def arity(self):
if self.op in UNARY_FUNCTIONS:
return 1
if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE):
return 1
return len(self)
def operator(self):
if self.op == OP_LOG:
base = self[1].value
if base == DEFAULT_LOGARITHM_BASE:
return self.value
if base == E:
return 'ln'
return '%s_%s' % (self.value, str(self[1]))
if self.op == OP_DXDER:
return self.value + str(self[1])
if self.op == OP_INT and len(self) == 4:
return bounds_str(Leaf('int'), self[2], self[3])
return self.value
def is_postfix(self):
return self.op in (OP_PRIME, OP_INT_INDEF)
def __str__(self): # pragma: nocover def __str__(self): # pragma: nocover
return generate_line(self) return generate_line(self)
def custom_line(self):
if self.op == OP_INT_INDEF:
Fx, a, b = self
return bounds_str(ExpressionNode(OP_BRACKETS, Fx), a, b)
def preprocess_str_exp(self):
if self.op == OP_PRIME and not self[0].is_op(OP_PRIME):
self[0] = ExpressionNode(OP_BRACKETS, self[0])
def postprocess_str(self, s):
if self.op == OP_INT:
return '%s d%s' % (s, self[1])
return s
def __eq__(self, other): def __eq__(self, other):
""" """
Check strict equivalence. Check strict equivalence.
...@@ -407,7 +444,7 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -407,7 +444,7 @@ class ExpressionNode(Node, ExpressionBase):
self.nodes[self.nodes.index(old_child)] = new_child self.nodes[self.nodes.index(old_child)] = new_child
def graph(self): # pragma: nocover def graph(self): # pragma: nocover
return generate_graph(negation_to_node(self)) return generate_graph(preprocess_node(self))
def extract_polynome_properties(self): def extract_polynome_properties(self):
""" """
...@@ -525,6 +562,7 @@ class ExpressionLeaf(Leaf, ExpressionBase): ...@@ -525,6 +562,7 @@ class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs) super(ExpressionLeaf, self).__init__(*args, **kwargs)
self.type = TYPE_MAP[type(args[0])] self.type = TYPE_MAP[type(args[0])]
self.parens = False
def __eq__(self, other): def __eq__(self, other):
""" """
...@@ -628,16 +666,22 @@ class Scope(object): ...@@ -628,16 +666,22 @@ class Scope(object):
def replace(self, node, replacement): def replace(self, node, replacement):
self.remove(node, replacement=replacement) self.remove(node, replacement=replacement)
#def as_nary_node(self):
def as_real_nary_node(self):
return ExpressionNode(self.node.op, *self.nodes) \
.negate(self.node.negated, clone=False)
#def as_binary_node(self):
def as_nary_node(self): def as_nary_node(self):
return nary_node(self.node.op, self.nodes).negate(self.node.negated) return nary_node(self.node.op, self.nodes) \
#return negate(nary_node(self.node.op, self.nodes), self.node.negated) .negate(self.node.negated, clone=False)
def all_except(self, node): def all_except(self, node):
before = range(0, node.scope_index) before = range(0, node.scope_index)
after = range(node.scope_index + 1, len(self)) after = range(node.scope_index + 1, len(self))
nodes = [self[i] for i in before + after] nodes = [self[i] for i in before + after]
return nary_node(self.node.op, nodes).negate(self.node.negated) return negate(nary_node(self.node.op, nodes), self.node.negated)
def nary_node(operator, scope): def nary_node(operator, scope):
...@@ -663,6 +707,14 @@ def get_scope(node): ...@@ -663,6 +707,14 @@ def get_scope(node):
else: else:
scope.append(child) scope.append(child)
#for child in node:
# if child.is_op(node.op) and (not child.negated or node.op == OP_MUL):
# sub_scope = get_scope(child)
# sub_scope[0] = sub_scope[0].negate(child.negated)
# scope += sub_scope
# else:
# scope.append(child)
return scope return scope
...@@ -716,10 +768,13 @@ def tan(*args): ...@@ -716,10 +768,13 @@ def tan(*args):
return ExpressionNode(OP_TAN, *args) return ExpressionNode(OP_TAN, *args)
def log(exponent, base=10): def log(exponent, base=None):
""" """
Create a logarithm function node (default base is 10). Create a logarithm function node (default base is 10).
""" """
if base is None:
base = DEFAULT_LOGARITHM_BASE
if not isinstance(base, ExpressionLeaf): if not isinstance(base, ExpressionLeaf):
base = ExpressionLeaf(base) base = ExpressionLeaf(base)
...@@ -737,7 +792,7 @@ def der(f, x=None): ...@@ -737,7 +792,7 @@ def der(f, x=None):
""" """
Create a derivative node. Create a derivative node.
""" """
return ExpressionNode(OP_DER, f, x) if x else ExpressionNode(OP_DER, f) return ExpressionNode(OP_DXDER, f, x) if x else ExpressionNode(OP_PRIME, f)
def integral(*args): def integral(*args):
...@@ -766,21 +821,3 @@ def sqrt(exp): ...@@ -766,21 +821,3 @@ def sqrt(exp):
Create a square root node. Create a square root node.
""" """
return ExpressionNode(OP_SQRT, exp) return ExpressionNode(OP_SQRT, exp)
def negation_to_node(node):
"""
Recursively replace negation flags inside a node by explicit unary negation
nodes.
"""
if node.negated:
negations = node.negated
node = negate(node, 0)
for i in range(negations):
node = ExpressionNode('-', node)
if node.is_leaf:
return node
return ExpressionNode(node.op, *map(negation_to_node, node))
...@@ -27,12 +27,13 @@ sys.path.insert(1, EXTERNAL_MODS) ...@@ -27,12 +27,13 @@ sys.path.insert(1, EXTERNAL_MODS)
from pybison import BisonParser, BisonSyntaxError from pybison import BisonParser, BisonSyntaxError
from graph_drawing.graph import generate_graph from graph_drawing.graph import generate_graph
from graph_drawing.line import pred
from node import ExpressionNode as Node, \ from node import ExpressionNode as Node, \
ExpressionLeaf as Leaf, OP_MAP, OP_DER, TOKEN_MAP, TYPE_OPERATOR, \ ExpressionLeaf as Leaf, OP_MAP, OP_DXDER, TOKEN_MAP, TYPE_OPERATOR, \
OP_COMMA, OP_MUL, OP_POW, OP_LOG, OP_ADD, Scope, E, OP_ABS, \ OP_COMMA, OP_MUL, OP_POW, OP_LOG, OP_ADD, Scope, E, OP_ABS, \
DEFAULT_LOGARITHM_BASE, OP_VALUE_MAP, SPECIAL_TOKENS, OP_INT, \ DEFAULT_LOGARITHM_BASE, OP_VALUE_MAP, SPECIAL_TOKENS, OP_INT, \
OP_INT_INDEF, negation_to_node OP_INT_INDEF, INFINITY, OP_PRIME, OP_DIV
from rules.utils import find_variable from rules.utils import find_variable
from rules.precedences import IMPLICIT_RULES from rules.precedences import IMPLICIT_RULES
from strategy import find_possibilities from strategy import find_possibilities
...@@ -45,6 +46,9 @@ import re ...@@ -45,6 +46,9 @@ import re
# Rewriting an expression is stopped after this number of steps is passed. # Rewriting an expression is stopped after this number of steps is passed.
MAXIMUM_REWRITE_STEPS = 30 MAXIMUM_REWRITE_STEPS = 30
# Precedence of the TIMES operator ("*")
TIMES_PRED = pred(Node(OP_MUL))
# Check for n-ary operator in child nodes # Check for n-ary operator in child nodes
def combine(op, op_type, *nodes): def combine(op, op_type, *nodes):
...@@ -70,12 +74,15 @@ def find_integration_variable(exp): ...@@ -70,12 +74,15 @@ def find_integration_variable(exp):
if len(scope) > 2 and scope[-2] == 'd' and scope[-1].is_identifier(): if len(scope) > 2 and scope[-2] == 'd' and scope[-1].is_identifier():
x = scope[-1] x = scope[-1]
scope.nodes = scope[:-2] scope.nodes = scope[:-2]
return scope.as_nary_node(), x return scope.as_nary_node(), x
return exp, find_variable(exp) return exp, find_variable(exp)
def apply_operator_negation(op, exp):
exp.negated += len(op) - 1
class Parser(BisonParser): class Parser(BisonParser):
""" """
Implements the calculator parser. Grammar rules are defined in the method Implements the calculator parser. Grammar rules are defined in the method
...@@ -95,9 +102,8 @@ class Parser(BisonParser): ...@@ -95,9 +102,8 @@ class Parser(BisonParser):
# TODO: add a runtime check to verify that this token list match the list # TODO: add a runtime check to verify that this token list match the list
# of tokens of the lex script. # of tokens of the lex script.
tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH', tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
'LPAREN', 'RPAREN', 'FUNCTION', 'FUNCTION_LPAREN', 'LBRACKET', 'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET',
'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE', 'PRIME', 'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \
'DERIVATIVE'] \
+ filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values()) + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
# ------------------------------ # ------------------------------
...@@ -108,19 +114,16 @@ class Parser(BisonParser): ...@@ -108,19 +114,16 @@ class Parser(BisonParser):
('left', ('OR', )), ('left', ('OR', )),
('left', ('AND', )), ('left', ('AND', )),
('left', ('EQ', )), ('left', ('EQ', )),
('left', ('MINUS', 'PLUS', 'NEG')), ('left', ('MINUS', 'PLUS')),
('left', ('INTEGRAL', 'DERIVATIVE')), ('nonassoc', ('INTEGRAL', 'DERIVATIVE')),
('left', ('TIMES', )), ('left', ('TIMES', )),
('left', ('PRIME', )),
('left', ('DIVIDE', )), ('left', ('DIVIDE', )),
('right', ('FUNCTION', )), ('nonassoc', ('PRIME', )),
('right', ('POW', )), ('nonassoc', ('NEG', )),
('left', ('SUB', )), ('nonassoc', ('FUNCTION', 'LOGARITHM')),
('right', ('FUNCTION_LPAREN', )), ('right', ('POW', 'SUB')),
) )
interactive = 0
def __init__(self, **kwargs): def __init__(self, **kwargs):
BisonParser.__init__(self, **kwargs) BisonParser.__init__(self, **kwargs)
self.interactive = kwargs.get('interactive', 0) self.interactive = kwargs.get('interactive', 0)
...@@ -199,6 +202,7 @@ class Parser(BisonParser): ...@@ -199,6 +202,7 @@ class Parser(BisonParser):
+ '|([0-9])\s*([' + rsv + 'a-z])' # 4a -> 4 * a + '|([0-9])\s*([' + rsv + 'a-z])' # 4a -> 4 * a
+ '|([' + rsv + 'a-z])([0-9])' # a4 -> a ^ 4 + '|([' + rsv + 'a-z])([0-9])' # a4 -> a ^ 4
+ '|([' + rsv + '0-9])(\s+[0-9]))' # 4 4 -> 4 * 4 + '|([' + rsv + '0-9])(\s+[0-9]))' # 4 4 -> 4 * 4
# FIXME: Last line is a bit useless
) )
def preprocess_data(match): def preprocess_data(match):
...@@ -241,8 +245,8 @@ class Parser(BisonParser): ...@@ -241,8 +245,8 @@ class Parser(BisonParser):
# Add parentheses to integrals with matching 'dx' so that the 'dx' acts # Add parentheses to integrals with matching 'dx' so that the 'dx' acts
# as a right parenthesis for the integral function # as a right parenthesis for the integral function
data = re.sub(r'(int(?:_.+\^.+\*)?)(.+?)(\*d\*[a-z])', #data = re.sub(r'(int(?:_.+(?:\^.+)?\*)?)(.+?)(\*d\*[a-z])',
'\\1(\\2)\\3', data) # '\\1(\\2)\\3', data)
if self.verbose and data_before != data: # pragma: nocover if self.verbose and data_before != data: # pragma: nocover
print 'hook_read_after() modified the input data:' print 'hook_read_after() modified the input data:'
...@@ -265,7 +269,6 @@ class Parser(BisonParser): ...@@ -265,7 +269,6 @@ class Parser(BisonParser):
if self.possibilities is not None: if self.possibilities is not None:
if self.verbose: if self.verbose:
print 'Expression has not changed, not updating possibilities' print 'Expression has not changed, not updating possibilities'
return return
self.possibilities = find_possibilities(self.root_node) self.possibilities = find_possibilities(self.root_node)
...@@ -455,9 +458,8 @@ class Parser(BisonParser): ...@@ -455,9 +458,8 @@ class Parser(BisonParser):
""" """
debug : GRAPH exp debug : GRAPH exp
""" """
if option == 0: if option == 0:
print generate_graph(negation_to_node(values[1])) print values[1].graph()
return values[1] return values[1]
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
...@@ -474,8 +476,6 @@ class Parser(BisonParser): ...@@ -474,8 +476,6 @@ class Parser(BisonParser):
| binary | binary
| nary | nary
""" """
# | concat
if option == 0: # rule: NUMBER if option == 0: # rule: NUMBER
# TODO: A bit hacky, this achieves long integers and floats. # TODO: A bit hacky, this achieves long integers and floats.
value = float(values[0]) if '.' in values[0] else int(values[0]) value = float(values[0]) if '.' in values[0] else int(values[0])
...@@ -486,6 +486,7 @@ class Parser(BisonParser): ...@@ -486,6 +486,7 @@ class Parser(BisonParser):
if 2 <= option <= 4: # rule: LPAREN exp RPAREN | LBRACKET exp RBRACKET if 2 <= option <= 4: # rule: LPAREN exp RPAREN | LBRACKET exp RBRACKET
# | LCBRACKET exp RCBRACKET # | LCBRACKET exp RCBRACKET
values[1].parens = pred(values[1]) > TIMES_PRED
return values[1] return values[1]
if 5 <= option <= 7: # rule: unary | binary | nary if 5 <= option <= 7: # rule: unary | binary | nary
...@@ -496,129 +497,159 @@ class Parser(BisonParser): ...@@ -496,129 +497,159 @@ class Parser(BisonParser):
def on_unary(self, target, option, names, values): def on_unary(self, target, option, names, values):
""" """
unary : MINUS exp unary : MINUS exp %prec NEG
| FUNCTION_LPAREN exp RPAREN
| FUNCTION exp | FUNCTION exp
| raised_function exp %prec FUNCTION
| DERIVATIVE exp | DERIVATIVE exp
| exp PRIME | exp PRIME
| INTEGRAL exp | INTEGRAL exp
| integral_bounds TIMES exp %prec INTEGRAL | integral_bounds exp %prec INTEGRAL
| LBRACKET exp RBRACKET lbnd ubnd
| PIPE exp PIPE | PIPE exp PIPE
| LOGARITHM exp
| logarithm_subscript exp %prec LOGARITHM
| TIMES exp
""" """
if option == 0: # rule: NEG exp if option == 0: # rule: MINUS exp
values[1].negated += 1 values[1].negated += 1
return values[1] return values[1]
if option in (1, 2): # rule: FUNCTION_LPAREN exp RPAREN | FUNCTION exp if option == 1: # rule: FUNCTION exp
op = values[0].split(' ', 1)[0]
if op == 'ln':
return Node(OP_LOG, values[1], Leaf(E))
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
return Node(op, *values[1]) return Node(values[0], *values[1])
if op == OP_VALUE_MAP[OP_LOG]:
return Node(OP_LOG, values[1], Leaf(DEFAULT_LOGARITHM_BASE))
m = re.match(r'^log_([0-9]+|[a-zA-Z])', op) return Node(values[0], values[1])
if m: if option == 2: # rule: raised_function exp
value = m.group(1) func, exponent = values[0]
if value.isdigit(): if values[1].is_op(OP_COMMA):
value = int(value) return Node(OP_POW, Node(func, *values[1]), exponent)
return Node(OP_LOG, values[1], Leaf(value))
return Node(op, values[1]) return Node(OP_POW, Node(func, values[1]), exponent)
if option == 3: # rule: DERIVATIVE exp if option == 3: # rule: DERIVATIVE exp
# DERIVATIVE looks like 'd/d*x*' -> extract the 'x' # DERIVATIVE looks like 'd/d*x' -> extract the 'x'
return Node(OP_DER, values[1], Leaf(values[0][-2])) return Node(OP_DXDER, values[1], Leaf(values[0][-1]))
if option == 4: # rule: exp PRIME if option == 4: # rule: exp PRIME
return Node(OP_DER, values[0]) return Node(OP_PRIME, values[0])
if option == 5: # rule: INTEGRAL exp if option == 5: # rule: INTEGRAL exp
fx, x = find_integration_variable(values[1]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x) return Node(OP_INT, fx, x)
if option == 6: # rule: integral_bounds TIMES exp if option == 6: # rule: integral_bounds exp
lbnd, ubnd = values[0] lbnd, ubnd = values[0]
fx, x = find_integration_variable(values[2]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x, lbnd, ubnd) return Node(OP_INT, fx, x, lbnd, ubnd)
if option == 7: # rule: LBRACKET exp RBRACKET lbnd ubnd if option == 7: # rule: PIPE exp PIPE
return Node(OP_INT_INDEF, values[1], values[3], values[4])
if option == 8: # rule: PIPE exp PIPE
return Node(OP_ABS, values[1]) return Node(OP_ABS, values[1])
raise BisonSyntaxError('Unsupported option %d in target "%s".' if option == 8: # rule: LOGARITHM exp
% (option, target)) # pragma: nocover if values[1].is_op(OP_COMMA):
return Node(OP_LOG, *values[1])
def on_integral_bounds(self, target, option, names, values): if values[0] == 'ln':
""" base = E
integral_bounds : INTEGRAL lbnd ubnd else:
""" base = DEFAULT_LOGARITHM_BASE
if option == 0: # rule: INTEGRAL lbnd ubnd
return values[1], values[2]
raise BisonSyntaxError('Unsupported option %d in target "%s".' return Node(OP_LOG, values[1], Leaf(base))
% (option, target)) # pragma: nocover
def on_lbnd(self, target, option, names, values): if option == 9: # rule: logarithm_subscript exp
""" if values[1].is_op(OP_COMMA):
lbnd : SUB exp raise BisonSyntaxError('Shortcut logarithm base "log_%s" does '
""" 'not support additional arguments.' % (values[0]))
if option == 0: # rule: SUB exp
return Node(OP_LOG, values[1], values[0])
if option == 10: # rule: TIMES exp
return values[1] return values[1]
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
% (option, target)) # pragma: nocover % (option, target)) # pragma: nocover
def on_ubnd(self, target, option, names, values): def on_raised_function(self, target, option, names, values):
""" """
ubnd : POW exp raised_function : FUNCTION POW exp
| LOGARITHM POW exp
""" """
if option == 0: # rule: POW exp # | logarithm_subscript POW exp
return values[1] if option in (0, 1): # rule: {FUNCTION,LOGARITHM} POW exp
apply_operator_negation(values[1], values[2])
raise BisonSyntaxError('Unsupported option %d in target "%s".' return values[0], values[2]
% (option, target)) # pragma: nocover
def on_power(self, target, option, names, values): def on_logarithm_subscript(self, target, option, names, values):
""" """
power : exp POW exp logarithm_subscript : LOGARITHM SUB exp
""" """
if option == 0: # rule: LOGARITHM SUB exp
apply_operator_negation(values[1], values[2])
return values[2]
if option == 0: # rule: exp POW exp def on_integral_bounds(self, target, option, names, values):
return values[0], values[2] """
integral_bounds : INTEGRAL SUB exp
"""
if option == 0: # rule: INTEGRAL SUB exp
if values[2].is_op(OP_POW):
lbnd, ubnd = values[2]
else:
lbnd = values[2]
ubnd = Leaf(INFINITY)
raise BisonSyntaxError('Unsupported option %d in target "%s".' apply_operator_negation(values[1], lbnd)
% (option, target)) # pragma: nocover return lbnd, ubnd
def on_binary(self, target, option, names, values): def on_binary(self, target, option, names, values):
""" """
binary : exp PLUS exp binary : exp TIMES exp
| exp TIMES exp | exp PLUS exp
| exp DIVIDE exp
| exp EQ exp | exp EQ exp
| exp AND exp | exp AND exp
| exp OR exp | exp OR exp
| exp DIVIDE exp
| exp MINUS exp | exp MINUS exp
| power | exp POW exp
| exp SUB exp
""" """
if 0 <= option <= 5: # rule: exp {PLUS,TIMES,DIVIDE,EQ,AND,OR} exp if option == 0: # rule: exp TIMES exp
first = values[0]
node = Node(values[1], first, values[2])
if first.negated and not first.parens:
node.negated += first.negated
first.negated = 0
return node
if 1 <= option <= 4: # rule: exp {PLUS,EQ,AND,OR} exp
return Node(values[1], values[0], values[2]) return Node(values[1], values[0], values[2])
if option == 5: # rule: exp DIVIDE exp
top = values[0]
bottom = values[2]
negated = 0
if top.negated and not top.parens:
negated = top.negated
top.negated = 0
if top.is_op(OP_MUL) and bottom.is_op(OP_MUL):
dtop, fx = top
dbot, x = bottom
if dtop.is_identifier('d') and dbot.is_identifier('d') \
and x.is_identifier():
# (d (fx)) / (dx)
return Node(OP_DXDER, fx, x, negated=negated)
return Node(OP_DIV, top, bottom, negated=negated)
if option == 6: # rule: exp MINUS exp if option == 6: # rule: exp MINUS exp
right = values[2] right = values[2]
right.negated += 1 right.negated += 1
...@@ -628,8 +659,22 @@ class Parser(BisonParser): ...@@ -628,8 +659,22 @@ class Parser(BisonParser):
return Node(OP_ADD, values[0], right) return Node(OP_ADD, values[0], right)
if option == 7: # rule: power if option == 7: # rule: exp POW exp
return Node(OP_POW, *values[0]) apply_operator_negation(values[1], values[2])
return Node(OP_POW, values[0], values[2])
if option == 8: # rule: exp SUB exp
bounds = values[2]
if bounds.is_op(OP_POW):
lbnd, ubnd = bounds
else:
lbnd = bounds
ubnd = Leaf(INFINITY)
lbnd.negated += len(values[1]) - 1
return Node(OP_INT_INDEF, values[0], lbnd, ubnd)
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
% (option, target)) # pragma: nocover % (option, target)) # pragma: nocover
...@@ -665,8 +710,6 @@ class Parser(BisonParser): ...@@ -665,8 +710,6 @@ class Parser(BisonParser):
# Put all functions in a single regex # Put all functions in a single regex
if functions: if functions:
operators += '("%s")[ ]*"(" { returntoken(FUNCTION_LPAREN); }\n' \
% '"|"'.join(functions)
operators += '("%s") { returntoken(FUNCTION); }\n' \ operators += '("%s") { returntoken(FUNCTION); }\n' \
% '"|"'.join(functions) % '"|"'.join(functions)
...@@ -710,7 +753,7 @@ class Parser(BisonParser): ...@@ -710,7 +753,7 @@ class Parser(BisonParser):
%% %%
d[ ]*"/"[ ]*"d*"[a-z]"*" { returntoken(DERIVATIVE); } d[ ]*"/"[ ]*"d*"[a-z] { returntoken(DERIVATIVE); }
[0-9]+"."?[0-9]* { returntoken(NUMBER); } [0-9]+"."?[0-9]* { returntoken(NUMBER); }
[a-zA-Z] { returntoken(IDENTIFIER); } [a-zA-Z] { returntoken(IDENTIFIER); }
"(" { returntoken(LPAREN); } "(" { returntoken(LPAREN); }
...@@ -719,10 +762,7 @@ class Parser(BisonParser): ...@@ -719,10 +762,7 @@ class Parser(BisonParser):
"]" { returntoken(RBRACKET); } "]" { returntoken(RBRACKET); }
"{" { returntoken(LCBRACKET); } "{" { returntoken(LCBRACKET); }
"}" { returntoken(RCBRACKET); } "}" { returntoken(RCBRACKET); }
"'" { returntoken(PRIME); }
"|" { returntoken(PIPE); } "|" { returntoken(PIPE); }
log_([0-9]+|[a-zA-Z])"*(" { returntoken(FUNCTION_LPAREN); }
log_([0-9]+|[a-zA-Z])"*" { returntoken(FUNCTION); }
""" + operators + r""" """ + operators + r"""
"raise" { returntoken(RAISE); } "raise" { returntoken(RAISE); }
"graph" { returntoken(GRAPH); } "graph" { returntoken(GRAPH); }
...@@ -736,4 +776,5 @@ class Parser(BisonParser): ...@@ -736,4 +776,5 @@ class Parser(BisonParser):
yywrap() { return(1); } yywrap() { return(1); }
""" """
#int[ ]*"(" { returntoken(FUNCTION_LPAREN); } #_-+ { returntoken(SUB); }
#"^"-+ { returntoken(POW); }
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from node import TYPE_OPERATOR from node import TYPE_OPERATOR, OP_MUL, Scope
import re import re
...@@ -80,6 +80,20 @@ def find_parent_node(root, child): ...@@ -80,6 +80,20 @@ def find_parent_node(root, child):
node = node[0] node = node[0]
def flatten_mult(node):
if node.is_leaf:
return node
if node.is_op(OP_MUL):
scope = Scope(node)
scope.nodes = map(flatten_mult, scope)
return scope.as_nary_node()
node.nodes = map(flatten_mult, node)
return node
def apply_suggestion(root, suggestion): def apply_suggestion(root, suggestion):
# TODO: clone the root node before modifying. After deep copying the root # TODO: clone the root node before modifying. After deep copying the root
# node, the subtree_map cannot be used since the hash() of each node in the # node, the subtree_map cannot be used since the hash() of each node in the
...@@ -100,6 +114,7 @@ def apply_suggestion(root, suggestion): ...@@ -100,6 +114,7 @@ def apply_suggestion(root, suggestion):
if parent_node: if parent_node:
parent_node.substitute(suggestion.root, subtree) parent_node.substitute(suggestion.root, subtree)
return root else:
root = subtree
return subtree return flatten_mult(root)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_NEG, OP_SIN, OP_COS, \ from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_NEG, OP_SIN, OP_COS, \
OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS, OP_SQRT, \ OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS, OP_SQRT, \
OP_AND, OP_OR OP_AND, OP_OR, OP_DXDER, OP_PRIME
from .groups import match_combine_groups from .groups import match_combine_groups
from .factors import match_expand from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \ from .powers import match_add_exponents, match_subtract_exponents, \
...@@ -85,3 +85,4 @@ RULES = { ...@@ -85,3 +85,4 @@ RULES = {
OP_AND: [match_multiple_equations, match_double_case], OP_AND: [match_multiple_equations, match_double_case],
OP_OR: [match_double_case], OP_OR: [match_double_case],
} }
RULES[OP_DXDER] = RULES[OP_PRIME] = RULES[OP_DER]
...@@ -188,9 +188,7 @@ def power_rule(root, args): ...@@ -188,9 +188,7 @@ def power_rule(root, args):
""" """
[f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]' [f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]'
""" """
x = second_arg(root) return der(L(E) ** ln(root[0]), second_arg(root))
return der(L(E) ** ln(root[0]), x)
MESSAGES[power_rule] = \ MESSAGES[power_rule] = \
......
...@@ -60,9 +60,9 @@ def match_expand(node): ...@@ -60,9 +60,9 @@ def match_expand(node):
def expand(root, args): def expand(root, args):
""" """
(a + b)(c + d) -> ac + ad + bc + bd
(a + b)c -> ac + bc
a(b + c) -> ab + ac a(b + c) -> ab + ac
(a + b)c -> ac + bc
(a + b)(c + d) -> ac + ad + bc + bd
etc.. etc..
""" """
scope, left, right = args scope, left, right = args
......
...@@ -226,7 +226,7 @@ def raised_base(root, args): ...@@ -226,7 +226,7 @@ def raised_base(root, args):
return args[0] return args[0]
MESSAGES[raised_base] = _('Apply `g ^ (log_(g)(a)) = a` on {0}.') MESSAGES[raised_base] = _('Apply `g ^ log_g(a) = a` to {0}.')
def match_factor_out_exponent(node): def match_factor_out_exponent(node):
......
...@@ -72,12 +72,7 @@ def add_numerics(root, args): ...@@ -72,12 +72,7 @@ def add_numerics(root, args):
-2 + -3 -> -5 -2 + -3 -> -5
""" """
scope, c0, c1 = args scope, c0, c1 = args
value = c0.actual_value() + c1.actual_value() scope.replace(c0, Leaf(c0.actual_value() + c1.actual_value()))
# Replace the left node with the new expression
scope.replace(c0, Leaf(abs(value), negated=int(value < 0)))
# Remove the right node
scope.remove(c1) scope.remove(c1)
return scope.as_nary_node() return scope.as_nary_node()
...@@ -193,13 +188,10 @@ def match_multiply_numerics(node): ...@@ -193,13 +188,10 @@ def match_multiply_numerics(node):
numerics = filter(is_numeric_node, scope) numerics = filter(is_numeric_node, scope)
for n in numerics: for n in numerics:
if n.negated:
continue
if n.value == 0: if n.value == 0:
p.append(P(node, multiply_zero, (n,))) p.append(P(node, multiply_zero, (n,)))
if n.value == 1: if not n.negated and n.value == 1:
p.append(P(node, multiply_one, (scope, n))) p.append(P(node, multiply_one, (scope, n)))
for c0, c1 in combinations(numerics, 2): for c0, c1 in combinations(numerics, 2):
......
...@@ -22,7 +22,7 @@ from .derivatives import chain_rule ...@@ -22,7 +22,7 @@ from .derivatives import chain_rule
from .negation import double_negation, negated_factor, negated_nominator, \ from .negation import double_negation, negated_factor, negated_nominator, \
negated_denominator, negated_zero negated_denominator, negated_zero
from .fractions import multiply_with_fraction, divide_fraction_by_term, \ from .fractions import multiply_with_fraction, divide_fraction_by_term, \
add_nominators add_nominators, division_by_one
from .integrals import factor_out_constant, integrate_variable_root from .integrals import factor_out_constant, integrate_variable_root
from .powers import remove_power_of_one from .powers import remove_power_of_one
from .sqrt import quadrant_sqrt, extract_sqrt_mult_priority from .sqrt import quadrant_sqrt, extract_sqrt_mult_priority
...@@ -37,6 +37,16 @@ HIGH = [ ...@@ -37,6 +37,16 @@ HIGH = [
# 4 / 4 + 1 / 4 -> 5 / 4 instead of 1 + 1/4 # 4 / 4 + 1 / 4 -> 5 / 4 instead of 1 + 1/4
add_nominators, add_nominators,
# Some operations are obvious, they are mostly done on-the-fly
multiply_zero,
multiply_one,
remove_zero,
double_negation,
division_by_one,
add_numerics,
multiply_numerics,
negated_factor,
] ]
...@@ -104,12 +114,12 @@ IMPLICIT_RULES = [ ...@@ -104,12 +114,12 @@ IMPLICIT_RULES = [
double_negation, double_negation,
negated_nominator, negated_nominator,
negated_denominator, negated_denominator,
multiply_one,
multiply_zero, multiply_zero,
multiply_one,
division_by_one,
negated_zero, negated_zero,
remove_zero, remove_zero,
remove_power_of_one, remove_power_of_one,
negated_factor,
add_numerics, add_numerics,
swap_factors, swap_factors,
] ]
...@@ -16,11 +16,10 @@ import sys ...@@ -16,11 +16,10 @@ import sys
from external.graph_drawing.graph import generate_graph from external.graph_drawing.graph import generate_graph
from external.graph_drawing.line import generate_line from external.graph_drawing.line import generate_line
from src.node import negation_to_node
def create_graph(node): def create_graph(node):
return generate_graph(negation_to_node(node)) return node.graph() if node else None
class ParserWrapper(object): class ParserWrapper(object):
......
...@@ -42,7 +42,7 @@ class TestB1Ch10(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestB1Ch10(unittest.TestCase):
) )
), ),
('-x^3*-2x^5', ('-x^3*-2x^5',
-(L('x') ** L(3) * -(L(2) * L('x') ** L(5))) -(L('x') ** L(3) * -L(2) * L('x') ** L(5))
), ),
('(7x^2y^3)^2/(7x^2y^3)', ('(7x^2y^3)^2/(7x^2y^3)',
N('/', N('/',
......
...@@ -19,45 +19,45 @@ class TestLeidenOefenopgave(TestCase): ...@@ -19,45 +19,45 @@ class TestLeidenOefenopgave(TestCase):
def test_1_1(self): def test_1_1(self):
self.assertRewrite([ self.assertRewrite([
'-5(x ^ 2 - 3x + 6)', '-5(x ^ 2 - 3x + 6)',
'-(5x ^ 2 + 5(-3x) + 5 * 6)', '-(5x ^ 2 + 5 * -3x + 5 * 6)',
'-(5x ^ 2 - 5 * 3x + 5 * 6)', '-(5x ^ 2 + (-15)x + 5 * 6)',
'-(5x ^ 2 - 15x + 5 * 6)', '-(5x ^ 2 + (-15)x + 30)',
'-(5x ^ 2 - 15x + 30)', '-(5x ^ 2 - 15x + 30)',
'-5x ^ 2 - -15x - 30', '-5x ^ 2 - -15x - 30',
'-5x ^ 2 + 15x - 30', '-5x ^ 2 + 15x - 30',
]) ])
return
#for exp, solution in [
for exp, solution in [ # ('-5(x^2 - 3x + 6)', '-30 + 15x - 5x ^ 2'),
('-5(x^2 - 3x + 6)', '-30 + 15x - 5x ^ 2'), # ('(x+1)^2', 'x ^ 2 + 2x + 1'),
('(x+1)^2', 'x ^ 2 + 2x + 1'), # ('(x-1)^2', 'x ^ 2 - 2x + 1'),
('(x-1)^2', 'x ^ 2 - 2x + 1'), # ('(2x+x)*x', '3x ^ 2'),
('(2x+x)*x', '3x ^ 2'), # ('-2(6x-4)^2*x', '-72x ^ 3 + 96x ^ 2 + 32x'),
('-2(6x-4)^2*x', '-72x ^ 3 + 96x ^ 2 + 32x'), # ('(4x + 5) * -(5 - 4x)', '16x^2 - 25'),
('(4x + 5) * -(5 - 4x)', '16x^2 - 25'), # ]:
]: # self.assertEqual(str(rewrite(exp)), solution)
self.assertEqual(str(rewrite(exp)), solution)
def test_1_2(self): def test_1_2(self):
self.assertRewrite([ self.assertRewrite([
'(x+1)^3', '(x + 1)(x + 1) ^ 2', '(x + 1) ^ 3',
'(x + 1)(x + 1) ^ 2',
'(x + 1)(x + 1)(x + 1)', '(x + 1)(x + 1)(x + 1)',
'(xx + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x * 1 + 1x + 1 * 1)(x + 1)',
'(x ^ (1 + 1) + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x + 1x + 1 * 1)(x + 1)',
'(x ^ 2 + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x + x + 1 * 1)(x + 1)',
'(x ^ 2 + x + 1x + 1 * 1)(x + 1)', '(xx + x + x + 1)(x + 1)',
'(x ^ 2 + x + x + 1 * 1)(x + 1)', '(x ^ (1 + 1) + x + x + 1)(x + 1)',
'(x ^ 2 + x + x + 1)(x + 1)', '(x ^ 2 + x + x + 1)(x + 1)',
'(x ^ 2 + (1 + 1)x + 1)(x + 1)', '(x ^ 2 + (1 + 1)x + 1)(x + 1)',
'(x ^ 2 + 2x + 1)(x + 1)', '(x ^ 2 + 2x + 1)(x + 1)',
'x ^ 2 * x + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1',
'x ^ (2 + 1) + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x * 1 + 1x + 1 * 1',
'x ^ 3 + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + 1x + 1 * 1',
'x ^ 3 + x ^ 2 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + x + 1 * 1',
'x ^ 3 + x ^ 2 + 2x ^ (1 + 1) + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x * 1 + 1x + 1 * 1', 'x ^ (2 + 1) + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + 1x + 1 * 1', 'x ^ 3 + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1 * 1', 'x ^ 3 + x ^ 2 + 2x ^ (1 + 1) + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1', 'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1',
'x ^ 3 + (1 + 2)x ^ 2 + 2x + x + 1', 'x ^ 3 + (1 + 2)x ^ 2 + 2x + x + 1',
'x ^ 3 + 3x ^ 2 + 2x + x + 1', 'x ^ 3 + 3x ^ 2 + 2x + x + 1',
...@@ -68,12 +68,13 @@ class TestLeidenOefenopgave(TestCase): ...@@ -68,12 +68,13 @@ class TestLeidenOefenopgave(TestCase):
def test_1_3(self): def test_1_3(self):
# (x+1)^2 -> x^2 + 2x + 1 # (x+1)^2 -> x^2 + 2x + 1
self.assertRewrite([ self.assertRewrite([
'(x+1)^2', '(x + 1)(x + 1)', '(x + 1) ^ 2',
'(x + 1)(x + 1)',
'xx + x * 1 + 1x + 1 * 1', 'xx + x * 1 + 1x + 1 * 1',
'x ^ (1 + 1) + x * 1 + 1x + 1 * 1', 'xx + x + 1x + 1 * 1',
'x ^ 2 + x * 1 + 1x + 1 * 1', 'xx + x + x + 1 * 1',
'x ^ 2 + x + 1x + 1 * 1', 'xx + x + x + 1',
'x ^ 2 + x + x + 1 * 1', 'x ^ (1 + 1) + x + x + 1',
'x ^ 2 + x + x + 1', 'x ^ 2 + x + x + 1',
'x ^ 2 + (1 + 1)x + 1', 'x ^ 2 + (1 + 1)x + 1',
'x ^ 2 + 2x + 1', 'x ^ 2 + 2x + 1',
...@@ -84,25 +85,25 @@ class TestLeidenOefenopgave(TestCase): ...@@ -84,25 +85,25 @@ class TestLeidenOefenopgave(TestCase):
self.assertRewrite([ self.assertRewrite([
'(x - 1) ^ 2', '(x - 1) ^ 2',
'(x - 1)(x - 1)', '(x - 1)(x - 1)',
'xx + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x + (-1) * -1',
'x ^ (1 + 1) + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x - -1',
'x ^ 2 + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x + 1',
'x ^ 2 - x * 1 + (-1)x + (-1)(-1)', 'xx - x * 1 + (-1)x + 1',
'x ^ 2 - x + (-1)x + (-1)(-1)', 'xx - x + (-1)x + 1',
'x ^ 2 - x - 1x + (-1)(-1)', 'xx - x - 1x + 1',
'x ^ 2 - x - x + (-1)(-1)', 'xx - x - x + 1',
'x ^ 2 - x - x - -1', 'x ^ (1 + 1) - x - x + 1',
'x ^ 2 - x - x + 1', 'x ^ 2 - x - x + 1',
'x ^ 2 + (1 + 1)(-x) + 1', 'x ^ 2 + (1 + 1) * -x + 1',
'x ^ 2 + 2(-x) + 1', 'x ^ 2 + 2 * -x + 1',
'x ^ 2 - 2x + 1', 'x ^ 2 - 2x + 1',
]) ])
def test_1_4_1(self): def test_1_4_1(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 + 1x', 'x * -1 + 1x',
'-x * 1 + 1x', 'x * -1 + x',
'-x + 1x', '-x * 1 + x',
'-x + x', '-x + x',
'(-1 + 1)x', '(-1 + 1)x',
'0x', '0x',
...@@ -112,23 +113,23 @@ class TestLeidenOefenopgave(TestCase): ...@@ -112,23 +113,23 @@ class TestLeidenOefenopgave(TestCase):
def test_1_4_2(self): def test_1_4_2(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 - 1x', 'x * -1 - 1x',
'-x * 1 - 1x', 'x * -1 - x',
'-x - 1x', '-x * 1 - x',
'-x - x', '-x - x',
'(1 + 1)(-x)', '(1 + 1) * -x',
'2(-x)', '2 * -x',
'-2x', '-2x',
]) ])
def test_1_4_3(self): def test_1_4_3(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 + x * -1', 'x * -1 + x * -1',
'-x * 1 + x(-1)', '-x * 1 + x * -1',
'-x + x(-1)', '-x + x * -1',
'-x - x * 1', '-x - x * 1',
'-x - x', '-x - x',
'(1 + 1)(-x)', '(1 + 1) * -x',
'2(-x)', '2 * -x',
'-2x', '-2x',
]) ])
...@@ -144,19 +145,21 @@ class TestLeidenOefenopgave(TestCase): ...@@ -144,19 +145,21 @@ class TestLeidenOefenopgave(TestCase):
def test_1_7(self): def test_1_7(self):
self.assertRewrite([ self.assertRewrite([
'(4x + 5) * -(5 - 4x)', '(4x + 5) * -(5 - 4x)',
'(4x + 5)(-5 - -4x)', '-(4x + 5)(5 - 4x)',
'(4x + 5)(-5 + 4x)', '-(4x * 5 + 4x * -4x + 5 * 5 + 5 * -4x)',
'4x(-5) + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + 4x * -4x + 5 * 5 + 5 * -4x)',
'(-20)x + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 5 * 5 + 5 * -4x)',
'-20x + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 25 + 5 * -4x)',
'-20x + 16xx + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 25 + (-20)x)',
'-20x + 16x ^ (1 + 1) + 5(-5) + 5 * 4x', '-(20x - 16xx + 25 + (-20)x)',
'-20x + 16x ^ 2 + 5(-5) + 5 * 4x', '-(20x - 16xx + 25 - 20x)',
'-20x + 16x ^ 2 - 25 + 5 * 4x', '-(20x - 16x ^ (1 + 1) + 25 - 20x)',
'-20x + 16x ^ 2 - 25 + 20x', '-(20x - 16x ^ 2 + 25 - 20x)',
'(-1 + 1)20x + 16x ^ 2 - 25', '-((1 - 1)20x - 16x ^ 2 + 25)',
'0 * 20x + 16x ^ 2 - 25', '-(0 * 20x - 16x ^ 2 + 25)',
'0 + 16x ^ 2 - 25', '-(0 - 16x ^ 2 + 25)',
'-(-16x ^ 2 + 25)',
'--16x ^ 2 - 25',
'16x ^ 2 - 25', '16x ^ 2 - 25',
]) ])
...@@ -176,7 +179,7 @@ class TestLeidenOefenopgave(TestCase): ...@@ -176,7 +179,7 @@ class TestLeidenOefenopgave(TestCase):
def test_4_2(self): def test_4_2(self):
self.assertRewrite([ self.assertRewrite([
'2/7 - 4/11', '2 / 7 - 4 / 11',
'22 / 77 - 28 / 77', '22 / 77 - 28 / 77',
'(22 - 28) / 77', '(22 - 28) / 77',
'(-6) / 77', '(-6) / 77',
......
...@@ -16,17 +16,6 @@ from tests.rulestestcase import RulesTestCase as TestCase ...@@ -16,17 +16,6 @@ from tests.rulestestcase import RulesTestCase as TestCase
class TestLeidenOefenopgaveV12(TestCase): class TestLeidenOefenopgaveV12(TestCase):
def test_1_a(self):
self.assertRewrite([
'-5(x^2 - 3x + 6)',
'-(5x ^ 2 + 5(-3x) + 5 * 6)',
'-(5x ^ 2 - 5 * 3x + 5 * 6)',
'-(5x ^ 2 - 15x + 5 * 6)',
'-(5x ^ 2 - 15x + 30)',
'-5x ^ 2 - -15x - 30',
'-5x ^ 2 + 15x - 30',
])
def test_1_d(self): def test_1_d(self):
self.assertRewrite([ self.assertRewrite([
'(2x + x)x', '(2x + x)x',
...@@ -40,30 +29,28 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -40,30 +29,28 @@ class TestLeidenOefenopgaveV12(TestCase):
self.assertRewrite([ self.assertRewrite([
'-2(6x - 4) ^ 2x', '-2(6x - 4) ^ 2x',
'-2(6x - 4)(6x - 4)x', '-2(6x - 4)(6x - 4)x',
'-(2 * 6x + 2(-4))(6x - 4)x', '-(2 * 6x + 2 * -4)(6x - 4)x',
'-(12x + 2(-4))(6x - 4)x', '-(12x + 2 * -4)(6x - 4)x',
'-(12x - 8)(6x - 4)x', '-(12x - 8)(6x - 4)x',
'-(12x - 8)(6xx + (-4)x)', '-(12x - 8)(6xx + (-4)x)',
'-(12x - 8)(6x ^ (1 + 1) + (-4)x)', '-(12x - 8)(6xx - 4x)',
'-(12x - 8)(6x ^ 2 + (-4)x)', '-(12x - 8)(6x ^ (1 + 1) - 4x)',
'-(12x - 8)(6x ^ 2 - 4x)', '-(12x - 8)(6x ^ 2 - 4x)',
'-(12x * 6x ^ 2 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(12x * 6x ^ 2 + 12x * -4x + (-8)6x ^ 2 + (-8) * -4x)',
'-(72xx ^ 2 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + 12x * -4x + (-8)6x ^ 2 + (-8) * -4x)',
'-(72x ^ (1 + 2) + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-8)6x ^ 2 + (-8) * -4x)',
'-(72x ^ 3 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + (-8) * -4x)',
'-(72x ^ 3 - 12x * 4x + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + (--32)x)',
'-(72x ^ 3 - 48xx + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ (1 + 1) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 - 48xx + (-48)x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 + (-48)x ^ 2 + (-8)(-4x))', '-(72x ^ (1 + 2) - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + (-8)(-4x))', '-(72x ^ 3 - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - 8(-4x))', '-(72x ^ 3 - 48x ^ (1 + 1) - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - -8 * 4x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - -32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + 32x)', '-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + 32x)',
'-(72x ^ 3 + (1 + 1)(-48x ^ 2) + 32x)', '-(72x ^ 3 + (1 + 1) * -48x ^ 2 + 32x)',
'-(72x ^ 3 + 2(-48x ^ 2) + 32x)', '-(72x ^ 3 + 2 * -48x ^ 2 + 32x)',
'-(72x ^ 3 - 2 * 48x ^ 2 + 32x)', '-(72x ^ 3 + (-96)x ^ 2 + 32x)',
'-(72x ^ 3 - 96x ^ 2 + 32x)', '-(72x ^ 3 - 96x ^ 2 + 32x)',
'-72x ^ 3 - -96x ^ 2 - 32x', '-72x ^ 3 - -96x ^ 2 - 32x',
'-72x ^ 3 + 96x ^ 2 - 32x', '-72x ^ 3 + 96x ^ 2 - 32x',
...@@ -71,23 +58,23 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -71,23 +58,23 @@ class TestLeidenOefenopgaveV12(TestCase):
def test_2_a(self): def test_2_a(self):
self.assertRewrite([ self.assertRewrite([
'(a ^ 2 * b ^ -1) ^ 3(ab ^ 2)', '(a ^ 2 * b ^ -1) ^ 3 * a b ^ 2',
'(a ^ 2 * 1 / b ^ 1) ^ 3 * ab ^ 2', '(a ^ 2 * 1 / b ^ 1) ^ 3 * a b ^ 2',
'(a ^ 2 * 1 / b) ^ 3 * ab ^ 2', '(a ^ 2 * 1 / b) ^ 3 * a b ^ 2',
'((a ^ 2 * 1) / b) ^ 3 * ab ^ 2', '((a ^ 2 * 1) / b) ^ 3 * a b ^ 2',
'(a ^ 2 / b) ^ 3 * ab ^ 2', '(a ^ 2 / b) ^ 3 * a b ^ 2',
'(a ^ 2) ^ 3 / b ^ 3 * ab ^ 2', '(a ^ 2) ^ 3 / b ^ 3 * a b ^ 2',
'a ^ (2 * 3) / b ^ 3 * ab ^ 2', 'a ^ (2 * 3) / b ^ 3 * a b ^ 2',
'a ^ 6 / b ^ 3 * ab ^ 2', 'a ^ 6 / b ^ 3 * a b ^ 2',
'(a ^ 6 * a) / b ^ 3 * b ^ 2', '(a ^ 6 * a) / b ^ 3 * b ^ 2',
'a ^ (6 + 1) / b ^ 3 * b ^ 2', 'a ^ (6 + 1) / b ^ 3 * b ^ 2',
'a ^ 7 / b ^ 3 * b ^ 2', 'a ^ 7 / b ^ 3 * b ^ 2',
'(a ^ 7 * b ^ 2) / b ^ 3', '(a ^ 7 * b ^ 2) / b ^ 3',
'b ^ 2 / b ^ 3 * a ^ 7 / 1', 'b ^ 2 / b ^ 3 * a ^ 7 / 1',
'b ^ (2 - 3)a ^ 7 / 1', 'b ^ 2 / b ^ 3 * a ^ 7',
'b ^ (-1)a ^ 7 / 1', 'b ^ (2 - 3)a ^ 7',
'1 / b ^ 1 * a ^ 7 / 1', 'b ^ -1 * a ^ 7',
'1 / b * a ^ 7 / 1', '1 / b ^ 1 * a ^ 7',
'1 / b * a ^ 7', '1 / b * a ^ 7',
'(1a ^ 7) / b', '(1a ^ 7) / b',
'a ^ 7 / b', 'a ^ 7 / b',
...@@ -124,9 +111,9 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -124,9 +111,9 @@ class TestLeidenOefenopgaveV12(TestCase):
def test_2_f(self): def test_2_f(self):
self.assertRewrite([ self.assertRewrite([
'(4b) ^ -2', '(4b) ^ -2',
'4 ^ (-2)b ^ (-2)', '4 ^ -2 * b ^ -2',
'1 / 4 ^ 2 * b ^ (-2)', '1 / 4 ^ 2 * b ^ -2',
'1 / 16 * b ^ (-2)', '1 / 16 * b ^ -2',
'1 / 16 * 1 / b ^ 2', '1 / 16 * 1 / b ^ 2',
'(1 * 1) / (16b ^ 2)', '(1 * 1) / (16b ^ 2)',
'1 / (16b ^ 2)', '1 / (16b ^ 2)',
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \ from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
nary_node, get_scope, OP_ADD, infinity, absolute, sin, cos, tan, log, \ nary_node, get_scope, OP_ADD, infinity, absolute, sin, cos, tan, log, \
ln, der, integral, indef, eq, negation_to_node ln, der, integral, indef, eq
from tests.rulestestcase import RulesTestCase, tree from tests.rulestestcase import RulesTestCase, tree
...@@ -121,8 +121,8 @@ class TestNode(RulesTestCase): ...@@ -121,8 +121,8 @@ class TestNode(RulesTestCase):
self.assertEqual(get_scope(plus), self.l) self.assertEqual(get_scope(plus), self.l)
def test_get_scope_negation(self): def test_get_scope_negation(self):
root, a, b, cd = tree('a * b * -cd, a, b, -cd') root, a, b, c, d = tree('ab * -cd, a, b, -c, d')
self.assertEqual(get_scope(root), [a, b, cd]) self.assertEqual(get_scope(root), [a, b, c, d])
def test_get_scope_index(self): def test_get_scope_index(self):
self.assertEqual(self.scope.index(self.a), 0) self.assertEqual(self.scope.index(self.a), 0)
...@@ -244,15 +244,15 @@ class TestNode(RulesTestCase): ...@@ -244,15 +244,15 @@ class TestNode(RulesTestCase):
self.assertTrue(ma.contains(a)) self.assertTrue(ma.contains(a))
def test_construct_function_derivative(self): def test_construct_function_derivative(self):
self.assertEqual(str(tree('der(x ^ 2)')), '[x ^ 2]\'') self.assertEqual(str(tree("(x ^ 2)'")), "[x ^ 2]'")
self.assertEqual(str(tree('der(der(x ^ 2))')), '[x ^ 2]\'\'') self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''")
self.assertEqual(str(tree('der(x ^ 2, x)')), 'd/dx (x ^ 2)') self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2')
def test_construct_function_logarithm(self): def test_construct_function_logarithm(self):
self.assertEqual(str(tree('log(x, e)')), 'ln(x)') self.assertEqual(str(tree('log(x, e)')), 'ln x')
self.assertEqual(str(tree('log(x, 10)')), 'log(x)') self.assertEqual(str(tree('log(x, 10)')), 'log x')
self.assertEqual(str(tree('log(x, 2)')), 'log_2(x)') self.assertEqual(str(tree('log(x, 2)')), 'log_2 x')
self.assertEqual(str(tree('log(x, g)')), 'log(x, g)') self.assertEqual(str(tree('log(x, g)')), 'log_g x')
def test_construct_function_integral(self): def test_construct_function_integral(self):
self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx') self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx')
...@@ -318,9 +318,3 @@ class TestNode(RulesTestCase): ...@@ -318,9 +318,3 @@ class TestNode(RulesTestCase):
def test_eq(self): def test_eq(self):
x, a, b, expect = tree('x, a, b, x + a = b') x, a, b, expect = tree('x, a, b, x + a = b')
self.assertEqual(eq(x + a, b), expect) self.assertEqual(eq(x + a, b), expect)
def test_negation_to_node(self):
a = tree('a')
self.assertEqual(negation_to_node(-a), N('-', a))
self.assertEqual(negation_to_node(-(a + 1)), N('-', a + 1))
self.assertEqual(negation_to_node(-(a - 1)), N('-', a + N('-', L(1))))
...@@ -89,12 +89,13 @@ class TestParser(RulesTestCase): ...@@ -89,12 +89,13 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('sin x'), sin(x)) self.assertEqual(tree('sin x'), sin(x))
self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct? self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct?
self.assertEqual(tree('sin x ^ 2'), sin(x ** 2)) self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2) self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2)) self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
self.assertEqual(tree('sin cos x'), sin(cos(x))) self.assertEqual(tree('sin cos x'), sin(cos(x)))
self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2)) self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
def test_brackets(self): def test_brackets(self):
self.assertEqual(*tree('[x], x')) self.assertEqual(*tree('[x], x'))
...@@ -145,7 +146,7 @@ class TestParser(RulesTestCase): ...@@ -145,7 +146,7 @@ class TestParser(RulesTestCase):
# FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a) # FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a)
def test_integral(self): def test_integral(self):
x, y, dx, a, b, l2 = tree('x, y, dx, a, b, 2') x, y, dx, a, b, l2, oo = tree('x, y, dx, a, b, 2, oo')
self.assertEqual(tree('int x'), integral(x, x)) self.assertEqual(tree('int x'), integral(x, x))
self.assertEqual(tree('int x ^ 2'), integral(x ** 2, x)) self.assertEqual(tree('int x ^ 2'), integral(x ** 2, x))
...@@ -162,21 +163,18 @@ class TestParser(RulesTestCase): ...@@ -162,21 +163,18 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('int_a^(b2) x'), integral(x, x, a, b * 2)) self.assertEqual(tree('int_a^(b2) x'), integral(x, x, a, b * 2))
self.assertEqual(tree('int x ^ 2 + 1'), integral(x ** 2, x) + 1) self.assertEqual(tree('int x ^ 2 + 1'), integral(x ** 2, x) + 1)
self.assertEqual(tree('int x ^ 2 + 1 dx'), integral(x ** 2 + 1, x))
self.assertEqual(tree('int_a^b x ^ 2 dx'), integral(x ** 2, x, a, b)) self.assertEqual(tree('int_a^b x ^ 2 dx'), integral(x ** 2, x, a, b))
self.assertEqual(tree('int_a^(b2) x ^ 2 + 1 dx'), self.assertEqual(tree('int_a x ^ 2 dx'), integral(x ** 2, x, a, oo))
integral(x ** 2 + 1, x, a, b * 2))
self.assertEqual(tree('int_(a^2)^b x ^ 2 + 1 dx'),
integral(x ** 2 + 1, x, a ** 2, b))
self.assertEqual(tree('int_(-a)^b x dx'), integral(x, x, -a, b)) self.assertEqual(tree('int_(-a)^b x dx'), integral(x, x, -a, b))
# FIXME: self.assertEqual(tree('int_-a^b x dx'), integral(x, x, -a, b)) #self.assertEqual(tree('int_-a^b x dx'), integral(x, x, -a, b))
def test_indefinite_integral(self): def test_indefinite_integral(self):
x2, a, b = tree('x ^ 2, a, b') x2, a, b, oo = tree('x ^ 2, a, b, oo')
self.assertEqual(tree('[x ^ 2]_a^b'), indef(x2, a, b)) self.assertEqual(tree('(x ^ 2)_a'), indef(x2, a, oo))
self.assertEqual(tree('(x ^ 2)_a^b'), indef(x2, a, b))
def test_absolute_value(self): def test_absolute_value(self):
x = tree('x') x = tree('x')
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from src.possibilities import MESSAGES, Possibility as P from src.possibilities import MESSAGES, Possibility as P, flatten_mult
from tests.rulestestcase import tree from tests.rulestestcase import tree
from src.parser import Parser from src.parser import Parser
...@@ -80,23 +80,6 @@ class TestPossibilities(unittest.TestCase): ...@@ -80,23 +80,6 @@ class TestPossibilities(unittest.TestCase):
'<Possibility root="3 + 4" handler=add_numerics' \ '<Possibility root="3 + 4" handler=add_numerics' \
' args=(<Scope of "3 + 4">, 3, 4)>') ' args=(<Scope of "3 + 4">, 3, 4)>')
#def test_filter_duplicates(self): def test_flatten_mult(self):
# a, b = ab = tree('a + b') self.assertEqual(flatten_mult(tree('2(xx)')), tree('2xx'))
# p0 = P(a, dummy_handler, (1, 2)) self.assertEqual(flatten_mult(tree('2(xx) + 1')), tree('2xx + 1'))
# p1 = P(ab, dummy_handler, (1, 2))
# p2 = P(ab, dummy_handler, (1, 2, 3))
# p3 = P(ab, dummy_handler_msg, (1, 2))
# self.assertEqual(filter_duplicates([]), [])
# self.assertEqual(filter_duplicates([p0, p1]), [p1])
# self.assertEqual(filter_duplicates([p1, p2]), [p1, p2])
# self.assertEqual(filter_duplicates([p1, p3]), [p1, p3])
# self.assertEqual(filter_duplicates([p0, p1, p2, p3]), [p1, p2, p3])
# # Docstrings example
# (l1, l2), l3 = left, l3 = right = tree('1 + 2 + 3')
# p0 = P(left, add_numerics, (1, 2, 1, 2))
# p1 = P(right, add_numerics, (1, 2, 1, 2))
# p2 = P(right, add_numerics, (1, 3, 1, 3))
# p3 = P(right, add_numerics, (2, 3, 2, 3))
# self.assertEqual(filter_duplicates([p0, p1, p2, p3]), [p1, p2, p3])
...@@ -28,50 +28,50 @@ from tests.rulestestcase import RulesTestCase, tree ...@@ -28,50 +28,50 @@ from tests.rulestestcase import RulesTestCase, tree
class TestRulesDerivatives(RulesTestCase): class TestRulesDerivatives(RulesTestCase):
def test_get_derivation_variable(self): def test_get_derivation_variable(self):
xy0, xy1, x, l1 = tree('der(xy, x), der(xy), der(x), der(1)') xy0, xy1, x, l1 = tree('d/dx xy, (xy)\', x\', 1\'')
self.assertEqual(get_derivation_variable(xy0), 'x') self.assertEqual(get_derivation_variable(xy0), 'x')
self.assertEqual(get_derivation_variable(xy1), 'x') self.assertEqual(get_derivation_variable(xy1), 'x')
self.assertEqual(get_derivation_variable(x), 'x') self.assertEqual(get_derivation_variable(x), 'x')
self.assertIsNone(get_derivation_variable(l1)) self.assertIsNone(get_derivation_variable(l1))
def test_match_zero_derivative(self): def test_match_zero_derivative(self):
root = tree('der(x, y)') root = tree('d/dy x')
self.assertEqualPos(match_zero_derivative(root), self.assertEqualPos(match_zero_derivative(root),
[P(root, zero_derivative)]) [P(root, zero_derivative)])
root = tree('der(2)') root = tree('d/dx 2')
self.assertEqualPos(match_zero_derivative(root), self.assertEqualPos(match_zero_derivative(root),
[P(root, zero_derivative)]) [P(root, zero_derivative)])
def test_zero_derivative(self): def test_zero_derivative(self):
root = tree('der(1)') root = tree('d/dx 1')
self.assertEqual(zero_derivative(root, ()), 0) self.assertEqual(zero_derivative(root, ()), 0)
def test_match_one_derivative(self): def test_match_one_derivative(self):
root = tree('der(x)') root = tree('d/dx x')
self.assertEqualPos(match_one_derivative(root), self.assertEqualPos(match_one_derivative(root),
[P(root, one_derivative)]) [P(root, one_derivative)])
root = tree('der(x, x)') root = tree('d/dx x')
self.assertEqualPos(match_one_derivative(root), self.assertEqualPos(match_one_derivative(root),
[P(root, one_derivative)]) [P(root, one_derivative)])
def test_one_derivative(self): def test_one_derivative(self):
root = tree('der(x)') root = tree('d/dx x')
self.assertEqual(one_derivative(root, ()), 1) self.assertEqual(one_derivative(root, ()), 1)
def test_match_const_deriv_multiplication(self): def test_match_const_deriv_multiplication(self):
root = tree('der(2x)') root = tree('d/dx 2x')
l2, x = root[0] l2, x = root[0]
self.assertEqualPos(match_const_deriv_multiplication(root), self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), l2, x))]) [P(root, const_deriv_multiplication, (Scope(root[0]), l2, x))])
(x, y), x = root = tree('der(xy, x)') (x, y), x = root = tree('d/dx xy')
self.assertEqualPos(match_const_deriv_multiplication(root), self.assertEqualPos(match_const_deriv_multiplication(root),
[P(root, const_deriv_multiplication, (Scope(root[0]), y, x))]) [P(root, const_deriv_multiplication, (Scope(root[0]), y, x))])
def test_match_const_deriv_multiplication_multiple_constants(self): def test_match_const_deriv_multiplication_multiple_constants(self):
root = tree('der(2x * 3)') root = tree('d/dx 2x * 3')
(l2, x), l3 = root[0] (l2, x), l3 = root[0]
scope = Scope(root[0]) scope = Scope(root[0])
self.assertEqualPos(match_const_deriv_multiplication(root), self.assertEqualPos(match_const_deriv_multiplication(root),
...@@ -79,33 +79,33 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -79,33 +79,33 @@ class TestRulesDerivatives(RulesTestCase):
P(root, const_deriv_multiplication, (scope, l3, x))]) P(root, const_deriv_multiplication, (scope, l3, x))])
def test_const_deriv_multiplication(self): def test_const_deriv_multiplication(self):
root = tree('der(2x)') root = tree('d/dx 2x')
l2, x = root[0] l2, x = root[0]
args = Scope(root[0]), l2, x args = Scope(root[0]), l2, x
self.assertEqual(const_deriv_multiplication(root, args), self.assertEqual(const_deriv_multiplication(root, args),
l2 * der(x, x)) l2 * der(x, x))
def test_match_variable_power(self): def test_match_variable_power(self):
root, x, l2 = tree('der(x ^ 2), x, 2') root, x, l2 = tree('d/dx x ^ 2, x, 2')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
[P(root, variable_root)]) [P(root, variable_root)])
root = tree('der(2 ^ x)') root = tree('d/dx 2 ^ x')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
[P(root, variable_exponent)]) [P(root, variable_exponent)])
def test_match_variable_power_chain_rule(self): def test_match_variable_power_chain_rule(self):
root, x, l2, x3 = tree('der((x ^ 3) ^ 2), x, 2, x ^ 3') root, x, l2, x3 = tree('d/dx (x ^ 3) ^ 2, x, 2, x ^ 3')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
[P(root, chain_rule, (x3, variable_root, ()))]) [P(root, chain_rule, (x3, variable_root, ()))])
root = tree('der(2 ^ x ^ 3)') root = tree('d/dx 2 ^ x ^ 3')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
[P(root, chain_rule, (x3, variable_exponent, ()))]) [P(root, chain_rule, (x3, variable_exponent, ()))])
# Below is not mathematically underivable, it's just not within the # Below is not mathematically underivable, it's just not within the
# scope of our program # scope of our program
root, x = tree('der(x ^ x), x') root, x = tree('d/dx x ^ x, x')
self.assertEqualPos(match_variable_power(root), self.assertEqualPos(match_variable_power(root),
[P(root, power_rule)]) [P(root, power_rule)])
...@@ -116,138 +116,138 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -116,138 +116,138 @@ class TestRulesDerivatives(RulesTestCase):
def test_power_rule_chain(self): def test_power_rule_chain(self):
self.assertRewrite([ self.assertRewrite([
"[x ^ x]'", "[x ^ x]'",
"[e ^ ln(x ^ x)]'", "[e ^ (ln x ^ x)]'",
"e ^ ln(x ^ x)[ln(x ^ x)]'", "e ^ (ln x ^ x)[ln x ^ x]'",
"x ^ x * [ln(x ^ x)]'", "x ^ x * [ln x ^ x]'",
"x ^ x * [xln(x)]'", "x ^ x * [x ln x]'",
"x ^ x * ([x]' * ln(x) + x[ln(x)]')", "x ^ x * ([x]' * ln x + x[ln x]')",
"x ^ x * (1ln(x) + x[ln(x)]')", "x ^ x * (1ln x + x[ln x]')",
"x ^ x * (ln(x) + x[ln(x)]')", "x ^ x * (ln x + x[ln x]')",
"x ^ x * (ln(x) + x * 1 / x)", "x ^ x * (ln x + x * 1 / x)",
"x ^ x * (ln(x) + (x * 1) / x)", "x ^ x * (ln x + (x * 1) / x)",
"x ^ x * (ln(x) + x / x)", "x ^ x * (ln x + x / x)",
"x ^ x * (ln(x) + 1)", "x ^ x * (ln x + 1)",
"x ^ x * ln(x) + x ^ x * 1", "x ^ x * ln x + x ^ x * 1",
"x ^ x * ln(x) + x ^ x", "x ^ x * ln x + x ^ x",
]) ])
def test_variable_root(self): def test_variable_root(self):
root = tree('der(x ^ 2)') root = tree('d/dx x ^ 2')
x, n = root[0] x, n = root[0]
self.assertEqual(variable_root(root, ()), n * x ** (n - 1)) self.assertEqual(variable_root(root, ()), n * x ** (n - 1))
def test_variable_exponent(self): def test_variable_exponent(self):
root = tree('der(2 ^ x)') root = tree('d/dx 2 ^ x')
g, x = root[0] g, x = root[0]
self.assertEqual(variable_exponent(root, ()), g ** x * ln(g)) self.assertEqual(variable_exponent(root, ()), g ** x * ln(g))
root = tree('der(e ^ x)') root = tree('d/dx e ^ x')
e, x = root[0] e, x = root[0]
self.assertEqual(variable_exponent(root, ()), e ** x) self.assertEqual(variable_exponent(root, ()), e ** x)
def test_chain_rule(self): def test_chain_rule(self):
root = tree('der(2 ^ x ^ 3)') root = tree('(2 ^ x ^ 3)\'')
l2, x3 = root[0] l2, x3 = root[0]
x, l3 = x3 x, l3 = x3
self.assertEqual(chain_rule(root, (x3, variable_exponent, ())), self.assertEqual(chain_rule(root, (x3, variable_exponent, ())),
l2 ** x3 * ln(l2) * der(x3)) l2 ** x3 * ln(l2) * der(x3))
def test_match_logarithmic(self): def test_match_logarithmic(self):
root = tree('der(log(x))') root = tree('d/dx log(x)')
self.assertEqualPos(match_logarithmic(root), [P(root, logarithmic)]) self.assertEqualPos(match_logarithmic(root), [P(root, logarithmic)])
def test_match_logarithmic_chain_rule(self): def test_match_logarithmic_chain_rule(self):
root, f = tree('der(log(x ^ 2)), x ^ 2') root, f = tree('d/dx log(x ^ 2), x ^ 2')
self.assertEqualPos(match_logarithmic(root), self.assertEqualPos(match_logarithmic(root),
[P(root, chain_rule, (f, logarithmic, ()))]) [P(root, chain_rule, (f, logarithmic, ()))])
def test_logarithmic(self): def test_logarithmic(self):
root, x, l1, l10 = tree('der(log(x)), x, 1, 10') root, x, l1, l10 = tree('d/dx log(x), x, 1, 10')
self.assertEqual(logarithmic(root, ()), l1 / (x * ln(l10))) self.assertEqual(logarithmic(root, ()), l1 / (x * ln(l10)))
root, x, l1, l10 = tree('der(ln(x)), x, 1, 10') root, x, l1, l10 = tree('d/dx ln(x), x, 1, 10')
self.assertEqual(logarithmic(root, ()), l1 / x) self.assertEqual(logarithmic(root, ()), l1 / x)
def test_match_goniometric(self): def test_match_goniometric(self):
root = tree('der(sin(x))') root = tree('d/dx sin(x)')
self.assertEqualPos(match_goniometric(root), [P(root, sinus)]) self.assertEqualPos(match_goniometric(root), [P(root, sinus)])
root = tree('der(cos(x))') root = tree('d/dx cos(x)')
self.assertEqualPos(match_goniometric(root), [P(root, cosinus)]) self.assertEqualPos(match_goniometric(root), [P(root, cosinus)])
root = tree('der(tan(x))') root = tree('d/dx tan(x)')
self.assertEqualPos(match_goniometric(root), [P(root, tangens)]) self.assertEqualPos(match_goniometric(root), [P(root, tangens)])
def test_match_goniometric_chain_rule(self): def test_match_goniometric_chain_rule(self):
root, x2 = tree('der(sin(x ^ 2)), x ^ 2') root, x2 = tree('d/dx sin(x ^ 2), x ^ 2')
self.assertEqualPos(match_goniometric(root), self.assertEqualPos(match_goniometric(root),
[P(root, chain_rule, (x2, sinus, ()))]) [P(root, chain_rule, (x2, sinus, ()))])
root = tree('der(cos(x ^ 2))') root = tree('d/dx cos(x ^ 2)')
self.assertEqualPos(match_goniometric(root), self.assertEqualPos(match_goniometric(root),
[P(root, chain_rule, (x2, cosinus, ()))]) [P(root, chain_rule, (x2, cosinus, ()))])
def test_sinus(self): def test_sinus(self):
root, x = tree('der(sin(x)), x') root, x = tree('d/dx sin(x), x')
self.assertEqual(sinus(root, ()), cos(x)) self.assertEqual(sinus(root, ()), cos(x))
def test_cosinus(self): def test_cosinus(self):
root, x = tree('der(cos(x)), x') root, x = tree('d/dx cos(x), x')
self.assertEqual(cosinus(root, ()), -sin(x)) self.assertEqual(cosinus(root, ()), -sin(x))
def test_tangens(self): def test_tangens(self):
root, x = tree('der(tan(x), x), x') root, x = tree('d/dx tan(x), x')
self.assertEqual(tangens(root, ()), der(sin(x) / cos(x), x)) self.assertEqual(tangens(root, ()), der(sin(x) / cos(x), x))
root = tree('der(tan(x))') root = tree('tan(x)\'')
self.assertEqual(tangens(root, ()), der(sin(x) / cos(x))) self.assertEqual(tangens(root, ()), der(sin(x) / cos(x)))
def test_match_sum_product_rule_sum(self): def test_match_sum_product_rule_sum(self):
root = tree('der(x ^ 2 + x)') root = tree('d/dx (x ^ 2 + x)')
x2, x = f = root[0] x2, x = f = root[0]
self.assertEqualPos(match_sum_product_rule(root), self.assertEqualPos(match_sum_product_rule(root),
[P(root, sum_rule, (Scope(f), x2)), [P(root, sum_rule, (Scope(f), x2)),
P(root, sum_rule, (Scope(f), x))]) P(root, sum_rule, (Scope(f), x))])
root = tree('der(x ^ 2 + 3 + x)') root = tree('d/dx (x ^ 2 + 3 + x)')
self.assertEqualPos(match_sum_product_rule(root), self.assertEqualPos(match_sum_product_rule(root),
[P(root, sum_rule, (Scope(root[0]), x2)), [P(root, sum_rule, (Scope(root[0]), x2)),
P(root, sum_rule, (Scope(root[0]), x))]) P(root, sum_rule, (Scope(root[0]), x))])
def test_match_sum_product_rule_product(self): def test_match_sum_product_rule_product(self):
root = tree('der(x ^ 2 * x)') root = tree('d/dx x ^ 2 * x')
x2, x = f = root[0] x2, x = f = root[0]
self.assertEqualPos(match_sum_product_rule(root), self.assertEqualPos(match_sum_product_rule(root),
[P(root, product_rule, (Scope(f), x2)), [P(root, product_rule, (Scope(f), x2)),
P(root, product_rule, (Scope(f), x))]) P(root, product_rule, (Scope(f), x))])
def test_match_sum_product_rule_none(self): def test_match_sum_product_rule_none(self):
root = tree('der(2 + 2)') root = tree('d/dx (2 + 2)')
self.assertEqualPos(match_sum_product_rule(root), []) self.assertEqualPos(match_sum_product_rule(root), [])
root = tree('der(x ^ 2 * 2)') root = tree('d/dx x ^ 2 * 2')
self.assertEqualPos(match_sum_product_rule(root), []) self.assertEqualPos(match_sum_product_rule(root), [])
def test_sum_rule(self): def test_sum_rule(self):
root = tree('der(x ^ 2 + x)') root = tree('(x ^ 2 + x)\'')
x2, x = f = root[0] x2, x = f = root[0]
self.assertEqual(sum_rule(root, (Scope(f), x2)), der(x2) + der(x)) self.assertEqual(sum_rule(root, (Scope(f), x2)), der(x2) + der(x))
self.assertEqual(sum_rule(root, (Scope(f), x)), der(x) + der(x2)) self.assertEqual(sum_rule(root, (Scope(f), x)), der(x) + der(x2))
root = tree('der(x ^ 2 + 3 + x)') root = tree('(x ^ 2 + 3 + x)\'')
(x2, l3), x = f = root[0] (x2, l3), x = f = root[0]
self.assertEqual(sum_rule(root, (Scope(f), x2)), der(x2) + der(l3 + x)) self.assertEqual(sum_rule(root, (Scope(f), x2)), der(x2) + der(l3 + x))
self.assertEqual(sum_rule(root, (Scope(f), x)), der(x) + der(x2 + l3)) self.assertEqual(sum_rule(root, (Scope(f), x)), der(x) + der(x2 + l3))
def test_product_rule(self): def test_product_rule(self):
root = tree('der(x ^ 2 * x)') root = tree('(x ^ 2 * x)\'')
x2, x = f = root[0] x2, x = f = root[0]
self.assertEqual(product_rule(root, (Scope(f), x2)), self.assertEqual(product_rule(root, (Scope(f), x2)),
der(x2) * x + x2 * der(x)) der(x2) * x + x2 * der(x))
self.assertEqual(product_rule(root, (Scope(f), x)), self.assertEqual(product_rule(root, (Scope(f), x)),
der(x) * x2 + x * der(x2)) der(x) * x2 + x * der(x2))
root = tree('der(x ^ 2 * x * x ^ 3)') root = tree('(x ^ 2 * x * x ^ 3)\'')
(x2, x), x3 = f = root[0] (x2, x), x3 = f = root[0]
self.assertEqual(product_rule(root, (Scope(f), x2)), self.assertEqual(product_rule(root, (Scope(f), x2)),
der(x2) * (x * x3) + x2 * der(x * x3)) der(x2) * (x * x3) + x2 * der(x * x3))
...@@ -257,15 +257,15 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -257,15 +257,15 @@ class TestRulesDerivatives(RulesTestCase):
der(x3) * (x2 * x) + x3 * der(x2 * x)) der(x3) * (x2 * x) + x3 * der(x2 * x))
def test_match_quotient_rule(self): def test_match_quotient_rule(self):
root = tree('der(x ^ 2 / x)') root = tree('d/dx x ^ 2 / x')
self.assertEqualPos(match_quotient_rule(root), self.assertEqualPos(match_quotient_rule(root),
[P(root, quotient_rule)]) [P(root, quotient_rule)])
root = tree('der(x ^ 2 / 2)') root = tree('d/dx x ^ 2 / 2')
self.assertEqualPos(match_quotient_rule(root), []) self.assertEqualPos(match_quotient_rule(root), [])
def test_quotient_rule(self): def test_quotient_rule(self):
root = tree('der(x ^ 2 / x)') root = tree('(x ^ 2 / x)\'')
f, g = root[0] f, g = root[0]
self.assertEqual(quotient_rule(root, ()), self.assertEqual(quotient_rule(root, ()),
(der(f) * g - f * der(g)) / g ** 2) (der(f) * g - f * der(g)) / g ** 2)
......
...@@ -320,17 +320,17 @@ class TestRulesFractions(RulesTestCase): ...@@ -320,17 +320,17 @@ class TestRulesFractions(RulesTestCase):
'1 / (1 / b - 1 / a)', '1 / (1 / b - 1 / a)',
'(b * 1) / (b(1 / b - 1 / a))', '(b * 1) / (b(1 / b - 1 / a))',
'b / (b(1 / b - 1 / a))', 'b / (b(1 / b - 1 / a))',
'b / (b * 1 / b + b(-1 / a))', 'b / (b * 1 / b + b * -1 / a)',
'b / ((b * 1) / b + b(-1 / a))', 'b / (b * 1 / b - b * 1 / a)',
'b / (b / b + b(-1 / a))', 'b / ((b * 1) / b - b * 1 / a)',
'b / (1 + b(-1 / a))', 'b / (b / b - b * 1 / a)',
'b / (1 - b * 1 / a)', 'b / (1 - b * 1 / a)',
'b / (1 - (b * 1) / a)', 'b / (1 - (b * 1) / a)',
'b / (1 - b / a)', 'b / (1 - b / a)',
'(ab) / (a(1 - b / a))', '(ab) / (a(1 - b / a))',
'(ab) / (a * 1 + a(-b / a))', '(ab) / (a * 1 + a * -b / a)',
'(ab) / (a + a(-b / a))', '(ab) / (a + a * -b / a)',
'(ab) / (a - ab / a)', '(ab) / (a - a b / a)',
'(ab) / (a - (ab) / a)', '(ab) / (a - (ab) / a)',
'(ab) / (a - b)', '(ab) / (a - b)',
]) ])
...@@ -30,30 +30,30 @@ class TestRulesGoniometry(RulesTestCase): ...@@ -30,30 +30,30 @@ class TestRulesGoniometry(RulesTestCase):
self.assertEqual(doctest.testmod(m=goniometry)[0], 0) self.assertEqual(doctest.testmod(m=goniometry)[0], 0)
def test_match_add_quadrants(self): def test_match_add_quadrants(self):
s, c = root = tree('sin(t) ^ 2 + cos(t) ^ 2') s, c = root = tree('sin^2 t + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
c, s = root = tree('cos(t) ^ 2 + sin(t) ^ 2') c, s = root = tree('cos^2 t + sin^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
(s, a), c = root = tree('sin(t) ^ 2 + a + cos(t) ^ 2') (s, a), c = root = tree('sin^2 t + a + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
(s, c0), c1 = root = tree('sin(t) ^ 2 + cos(t) ^ 2 + cos(t) ^ 2') (s, c0), c1 = root = tree('sin^2 t + cos^2 t + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c0)), [P(root, add_quadrants, (Scope(root), s, c0)),
P(root, add_quadrants, (Scope(root), s, c1))]) P(root, add_quadrants, (Scope(root), s, c1))])
root = tree('sin(t) ^ 2 + cos(y) ^ 2') root = tree('sin^2 t + cos^2 y')
self.assertEqualPos(match_add_quadrants(root), []) self.assertEqualPos(match_add_quadrants(root), [])
root = tree('sin(t) ^ 2 - cos(t) ^ 2') root = tree('sin^2 t - cos^2 t')
self.assertEqualPos(match_add_quadrants(root), []) self.assertEqualPos(match_add_quadrants(root), [])
s, c = root = tree('-sin(t) ^ 2 - cos(t) ^ 2') s, c = root = tree('-sin^2 t - cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, factor_out_quadrant_negation, (Scope(root), s, c))]) [P(root, factor_out_quadrant_negation, (Scope(root), s, c))])
......
...@@ -141,9 +141,9 @@ class TestRulesIntegrals(RulesTestCase): ...@@ -141,9 +141,9 @@ class TestRulesIntegrals(RulesTestCase):
self.assertRewrite([ self.assertRewrite([
'int a / x', 'int a / x',
'int a * 1 / x dx', 'int a * 1 / x dx',
'aint 1 / x dx', 'a(int 1 / x dx)',
'a(ln|x| + C)', 'a(ln|x| + C)',
'aln|x| + aC', 'a ln|x| + aC',
# FIXME: 'aln|x| + C', # ac -> C # FIXME: 'aln|x| + C', # ac -> C
]) ])
...@@ -176,24 +176,24 @@ class TestRulesIntegrals(RulesTestCase): ...@@ -176,24 +176,24 @@ class TestRulesIntegrals(RulesTestCase):
self.assertEqual(cosinus_integral(root, ()), expect) self.assertEqual(cosinus_integral(root, ()), expect)
def test_match_sum_rule_integral(self): def test_match_sum_rule_integral(self):
(f, g), x = root = tree('int 2x + 3x dx') (f, g), x = root = tree('int (2x + 3x) dx')
self.assertEqualPos(match_sum_rule_integral(root), self.assertEqualPos(match_sum_rule_integral(root),
[P(root, sum_rule_integral, (Scope(root[0]), f))]) [P(root, sum_rule_integral, (Scope(root[0]), f))])
((f, g), h), x = root = tree('int 2x + 3x + 4x dx') ((f, g), h), x = root = tree('int (2x + 3x + 4x) dx')
self.assertEqualPos(match_sum_rule_integral(root), self.assertEqualPos(match_sum_rule_integral(root),
[P(root, sum_rule_integral, (Scope(root[0]), f)), [P(root, sum_rule_integral, (Scope(root[0]), f)),
P(root, sum_rule_integral, (Scope(root[0]), g)), P(root, sum_rule_integral, (Scope(root[0]), g)),
P(root, sum_rule_integral, (Scope(root[0]), h))]) P(root, sum_rule_integral, (Scope(root[0]), h))])
def test_sum_rule_integral(self): def test_sum_rule_integral(self):
((f, g), h), x = root = tree('int 2x + 3x + 4x dx') ((f, g), h), x = root = tree('int (2x + 3x + 4x) dx')
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), f)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), f)),
tree('int 2x dx + int 3x + 4x dx')) tree('int 2x dx + int (3x + 4x) dx'))
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), g)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), g)),
tree('int 3x dx + int 2x + 4x dx')) tree('int 3x dx + int (2x + 4x) dx'))
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), h)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), h)),
tree('int 4x dx + int 2x + 3x dx')) tree('int 4x dx + int (2x + 3x) dx'))
def test_match_remove_indef_constant(self): def test_match_remove_indef_constant(self):
Fx, a, b = root = tree('[2x + C]_a^b') Fx, a, b = root = tree('[2x + C]_a^b')
......
...@@ -98,8 +98,8 @@ class TestRulesLineq(RulesTestCase): ...@@ -98,8 +98,8 @@ class TestRulesLineq(RulesTestCase):
'2x = -3x - 5', '2x = -3x - 5',
'2x - -3x = -3x - 5 - -3x', '2x - -3x = -3x - 5 - -3x',
'2x + 3x = -3x - 5 - -3x', '2x + 3x = -3x - 5 - -3x',
'(2 + 3)x = -3x - 5 - -3x', '2x + 3x = -3x - 5 + 3x',
'5x = -3x - 5 - -3x', '(2 + 3)x = -3x - 5 + 3x',
'5x = -3x - 5 + 3x', '5x = -3x - 5 + 3x',
'5x = (-1 + 1)3x - 5', '5x = (-1 + 1)3x - 5',
'5x = 0 * 3x - 5', '5x = 0 * 3x - 5',
...@@ -116,11 +116,11 @@ class TestRulesLineq(RulesTestCase): ...@@ -116,11 +116,11 @@ class TestRulesLineq(RulesTestCase):
def test_match_move_term_chain_advanced(self): def test_match_move_term_chain_advanced(self):
self.assertRewrite([ self.assertRewrite([
'-x = a', '-x = a',
'(-x)(-1) = a(-1)', '(-x) * -1 = a * -1',
'-x(-1) = a(-1)', '-x * -1 = a * -1',
'--x * 1 = a(-1)', '--x * 1 = a * -1',
'--x = a(-1)', '--x = a * -1',
'x = a(-1)', 'x = a * -1',
'x = -a * 1', 'x = -a * 1',
'x = -a', 'x = -a',
]) ])
......
...@@ -106,8 +106,8 @@ class TestRulesNegation(RulesTestCase): ...@@ -106,8 +106,8 @@ class TestRulesNegation(RulesTestCase):
def test_double_negated_division(self): def test_double_negated_division(self):
self.assertRewrite([ self.assertRewrite([
'(-a) / (-b)', '(-a) / -b',
'-a / (-b)', '-a / -b',
'--a / b', '--a / b',
'a / b', 'a / b',
]) ])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment