Commit 72f94a0f authored by Taddeüs Kroes's avatar Taddeüs Kroes

Significantly improved the parser.

- Thought over and reconfigured precedences. These are now also synchronized
  with precedences in the line printer.
- The subscript operator ("_") has been added to the grammar similarly to
  exponentiation ("^"). This resolved some issues with the use of operators in
  integral bounds.
- Unary negation now has a higher precedence, but the parser moves negations up
  in the tree as far as possible (within the bounds of correctness and without
  causing cycles during reduction).
- All unit tests have been updated using the new syntax.
- graph_drawing has been updated to include the new syntax of the line printer.
parent 9cbf75f1
graph_drawing @ e17779f2
Subproject commit 107e93a1697600860f4e02175f47cf8650affd0f Subproject commit e17779f211ca67336c3a6ac9038f552f91ff2f79
...@@ -21,7 +21,7 @@ import re ...@@ -21,7 +21,7 @@ import re
sys.path.insert(0, os.path.realpath('external')) sys.path.insert(0, os.path.realpath('external'))
from graph_drawing.graph import generate_graph from graph_drawing.graph import generate_graph
from graph_drawing.line import generate_line from graph_drawing.line import generate_line, preprocess_node
from graph_drawing.node import Node, Leaf from graph_drawing.node import Node, Leaf
...@@ -70,6 +70,16 @@ OP_REWRITE_ALL = 24 ...@@ -70,6 +70,16 @@ OP_REWRITE_ALL = 24
OP_REWRITE_ALL_VERBOSE = 25 OP_REWRITE_ALL_VERBOSE = 25
OP_REWRITE = 26 OP_REWRITE = 26
# Different types of derivative
OP_PRIME = 27
OP_DXDER = 28
OP_PARENS = 29
OP_BRACKETS = 30
OP_CBRACKETS = 31
UNARY_FUNCTIONS = [OP_INT, OP_DXDER, OP_LOG]
# Special identifiers # Special identifiers
PI = 'pi' PI = 'pi'
E = 'e' E = 'e'
...@@ -102,7 +112,7 @@ OP_MAP = { ...@@ -102,7 +112,7 @@ OP_MAP = {
'tan': OP_TAN, 'tan': OP_TAN,
'sqrt': OP_SQRT, 'sqrt': OP_SQRT,
'int': OP_INT, 'int': OP_INT,
'der': OP_DER, '\'': OP_PRIME,
'solve': OP_SOLVE, 'solve': OP_SOLVE,
'log': OP_LOG, 'log': OP_LOG,
'=': OP_EQ, '=': OP_EQ,
...@@ -114,9 +124,13 @@ OP_MAP = { ...@@ -114,9 +124,13 @@ OP_MAP = {
} }
OP_VALUE_MAP = dict([(v, k) for k, v in OP_MAP.iteritems()]) OP_VALUE_MAP = dict([(v, k) for k, v in OP_MAP.iteritems()])
OP_MAP['ln'] = OP_LOG
OP_VALUE_MAP[OP_INT_INDEF] = 'indef' OP_VALUE_MAP[OP_INT_INDEF] = 'indef'
OP_VALUE_MAP[OP_ABS] = 'abs' OP_VALUE_MAP[OP_ABS] = '||'
OP_VALUE_MAP[OP_DXDER] = 'd/d'
OP_VALUE_MAP[OP_PARENS] = '()'
OP_VALUE_MAP[OP_BRACKETS] = '[]'
OP_VALUE_MAP[OP_CBRACKETS] = '{}'
OP_MAP['ln'] = OP_LOG
TOKEN_MAP = { TOKEN_MAP = {
OP_COMMA: 'COMMA', OP_COMMA: 'COMMA',
...@@ -133,9 +147,10 @@ TOKEN_MAP = { ...@@ -133,9 +147,10 @@ TOKEN_MAP = {
OP_COS: 'FUNCTION', OP_COS: 'FUNCTION',
OP_TAN: 'FUNCTION', OP_TAN: 'FUNCTION',
OP_INT: 'INTEGRAL', OP_INT: 'INTEGRAL',
OP_DER: 'FUNCTION', OP_DXDER: 'DERIVATIVE',
OP_PRIME: 'PRIME',
OP_SOLVE: 'FUNCTION', OP_SOLVE: 'FUNCTION',
OP_LOG: 'FUNCTION', OP_LOG: 'LOGARITHM',
OP_EQ: 'EQ', OP_EQ: 'EQ',
OP_POSSIBILITIES: 'POSSIBILITIES', OP_POSSIBILITIES: 'POSSIBILITIES',
OP_HINT: 'HINT', OP_HINT: 'HINT',
...@@ -152,6 +167,11 @@ def to_expression(obj): ...@@ -152,6 +167,11 @@ def to_expression(obj):
return ExpressionLeaf(obj) return ExpressionLeaf(obj)
def bounds_str(f, a, b):
left = str(ExpressionNode(OP_SUBSCRIPT, f, a, no_spacing=True))
return left + str(ExpressionNode(OP_POW, Leaf(1), b, no_spacing=True))[1:]
class ExpressionBase(object): class ExpressionBase(object):
def __lt__(self, other): def __lt__(self, other):
""" """
...@@ -208,7 +228,8 @@ class ExpressionBase(object): ...@@ -208,7 +228,8 @@ class ExpressionBase(object):
return copy.deepcopy(self) return copy.deepcopy(self)
def is_op(self, *ops): def is_op(self, *ops):
return not self.is_leaf and self.op in ops return not self.is_leaf and (self.op in ops or
(self.op in (OP_DXDER, OP_PRIME) and OP_DER in ops))
def is_power(self, exponent=None): def is_power(self, exponent=None):
if self.is_leaf or self.op != OP_POW: if self.is_leaf or self.op != OP_POW:
...@@ -266,9 +287,9 @@ class ExpressionBase(object): ...@@ -266,9 +287,9 @@ class ExpressionBase(object):
return self.negate(-n) return self.negate(-n)
def negate(self, n=1): def negate(self, n=1, clone=True):
"""Negate the node n times.""" """Negate the node n times."""
return negate(self, self.negated + n, clone=True) return negate(self, self.negated + n, clone=clone)
def contains(self, node, include_self=True): def contains(self, node, include_self=True):
""" """
...@@ -290,6 +311,7 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -290,6 +311,7 @@ class ExpressionNode(Node, ExpressionBase):
super(ExpressionNode, self).__init__(*args, **kwargs) super(ExpressionNode, self).__init__(*args, **kwargs)
self.type = TYPE_OPERATOR self.type = TYPE_OPERATOR
op = args[0] op = args[0]
self.parens = False
if isinstance(op, str): if isinstance(op, str):
self.value = op self.value = op
...@@ -298,38 +320,6 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -298,38 +320,6 @@ class ExpressionNode(Node, ExpressionBase):
self.value = OP_VALUE_MAP[op] self.value = OP_VALUE_MAP[op]
self.op = op self.op = op
def construct_derivative(self, children):
f = children[0]
if len(children) < 2:
# der(der(x ^ 2)) -> [x ^ 2]''
if self[0].is_op(OP_DER) and len(self[0]) < 2:
return f + '\''
# der(x ^ 2) -> [x ^ 2]'
return '[' + f + ']\''
# der(x ^ 2, x) -> d/dx (x ^ 2)
return 'd/d%s (%s)' % (children[1], f)
def construct_logarithm(self, children):
if self[0].is_op(OP_ABS):
content = children[0]
else:
content = '(' + children[0] + ')'
# log(a, e) -> ln(a)
if self[1].is_identifier(E):
return 'ln%s' % content
# log(a, 10) -> log(a)
if self[1] == 10:
return 'log%s' % content
# log(a, 2) -> log_2(a)
if children[1].isdigit():
return 'log_%s%s' % (children[1], content)
def construct_integral(self, children): def construct_integral(self, children):
# Make sure that any needed parentheses around f(x) are generated, # Make sure that any needed parentheses around f(x) are generated,
# and append ' dx' to it (result 'f(x) dx') # and append ' dx' to it (result 'f(x) dx')
...@@ -375,8 +365,8 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -375,8 +365,8 @@ class ExpressionNode(Node, ExpressionBase):
return '|%s|' % children[0] return '|%s|' % children[0]
constructors = { constructors = {
OP_DER: self.construct_derivative, #OP_DER: self.construct_derivative,
OP_LOG: self.construct_logarithm, #OP_LOG: self.construct_logarithm,
OP_INT: self.construct_integral, OP_INT: self.construct_integral,
OP_INT_INDEF: self.construct_indef_integral OP_INT_INDEF: self.construct_indef_integral
} }
...@@ -393,9 +383,56 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -393,9 +383,56 @@ class ExpressionNode(Node, ExpressionBase):
and len(self) == 1 and self[0].is_op(OP_ABS): and len(self) == 1 and self[0].is_op(OP_ABS):
return self.title() + children[0] return self.title() + children[0]
def arity(self):
if self.op in UNARY_FUNCTIONS:
return 1
if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE):
return 1
return len(self)
def operator(self):
if self.op == OP_LOG:
base = self[1].value
if base == DEFAULT_LOGARITHM_BASE:
return self.value
if base == E:
return 'ln'
return '%s_%s' % (self.value, str(self[1]))
if self.op == OP_DXDER:
return self.value + str(self[1])
if self.op == OP_INT and len(self) == 4:
return bounds_str(Leaf('int'), self[2], self[3])
return self.value
def is_postfix(self):
return self.op in (OP_PRIME, OP_INT_INDEF)
def __str__(self): # pragma: nocover def __str__(self): # pragma: nocover
return generate_line(self) return generate_line(self)
def custom_line(self):
if self.op == OP_INT_INDEF:
Fx, a, b = self
return bounds_str(ExpressionNode(OP_BRACKETS, Fx), a, b)
def preprocess_str_exp(self):
if self.op == OP_PRIME and not self[0].is_op(OP_PRIME):
self[0] = ExpressionNode(OP_BRACKETS, self[0])
def postprocess_str(self, s):
if self.op == OP_INT:
return '%s d%s' % (s, self[1])
return s
def __eq__(self, other): def __eq__(self, other):
""" """
Check strict equivalence. Check strict equivalence.
...@@ -407,7 +444,7 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -407,7 +444,7 @@ class ExpressionNode(Node, ExpressionBase):
self.nodes[self.nodes.index(old_child)] = new_child self.nodes[self.nodes.index(old_child)] = new_child
def graph(self): # pragma: nocover def graph(self): # pragma: nocover
return generate_graph(negation_to_node(self)) return generate_graph(preprocess_node(self))
def extract_polynome_properties(self): def extract_polynome_properties(self):
""" """
...@@ -525,6 +562,7 @@ class ExpressionLeaf(Leaf, ExpressionBase): ...@@ -525,6 +562,7 @@ class ExpressionLeaf(Leaf, ExpressionBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExpressionLeaf, self).__init__(*args, **kwargs) super(ExpressionLeaf, self).__init__(*args, **kwargs)
self.type = TYPE_MAP[type(args[0])] self.type = TYPE_MAP[type(args[0])]
self.parens = False
def __eq__(self, other): def __eq__(self, other):
""" """
...@@ -628,16 +666,22 @@ class Scope(object): ...@@ -628,16 +666,22 @@ class Scope(object):
def replace(self, node, replacement): def replace(self, node, replacement):
self.remove(node, replacement=replacement) self.remove(node, replacement=replacement)
#def as_nary_node(self):
def as_real_nary_node(self):
return ExpressionNode(self.node.op, *self.nodes) \
.negate(self.node.negated, clone=False)
#def as_binary_node(self):
def as_nary_node(self): def as_nary_node(self):
return nary_node(self.node.op, self.nodes).negate(self.node.negated) return nary_node(self.node.op, self.nodes) \
#return negate(nary_node(self.node.op, self.nodes), self.node.negated) .negate(self.node.negated, clone=False)
def all_except(self, node): def all_except(self, node):
before = range(0, node.scope_index) before = range(0, node.scope_index)
after = range(node.scope_index + 1, len(self)) after = range(node.scope_index + 1, len(self))
nodes = [self[i] for i in before + after] nodes = [self[i] for i in before + after]
return nary_node(self.node.op, nodes).negate(self.node.negated) return negate(nary_node(self.node.op, nodes), self.node.negated)
def nary_node(operator, scope): def nary_node(operator, scope):
...@@ -663,6 +707,14 @@ def get_scope(node): ...@@ -663,6 +707,14 @@ def get_scope(node):
else: else:
scope.append(child) scope.append(child)
#for child in node:
# if child.is_op(node.op) and (not child.negated or node.op == OP_MUL):
# sub_scope = get_scope(child)
# sub_scope[0] = sub_scope[0].negate(child.negated)
# scope += sub_scope
# else:
# scope.append(child)
return scope return scope
...@@ -716,10 +768,13 @@ def tan(*args): ...@@ -716,10 +768,13 @@ def tan(*args):
return ExpressionNode(OP_TAN, *args) return ExpressionNode(OP_TAN, *args)
def log(exponent, base=10): def log(exponent, base=None):
""" """
Create a logarithm function node (default base is 10). Create a logarithm function node (default base is 10).
""" """
if base is None:
base = DEFAULT_LOGARITHM_BASE
if not isinstance(base, ExpressionLeaf): if not isinstance(base, ExpressionLeaf):
base = ExpressionLeaf(base) base = ExpressionLeaf(base)
...@@ -737,7 +792,7 @@ def der(f, x=None): ...@@ -737,7 +792,7 @@ def der(f, x=None):
""" """
Create a derivative node. Create a derivative node.
""" """
return ExpressionNode(OP_DER, f, x) if x else ExpressionNode(OP_DER, f) return ExpressionNode(OP_DXDER, f, x) if x else ExpressionNode(OP_PRIME, f)
def integral(*args): def integral(*args):
...@@ -766,21 +821,3 @@ def sqrt(exp): ...@@ -766,21 +821,3 @@ def sqrt(exp):
Create a square root node. Create a square root node.
""" """
return ExpressionNode(OP_SQRT, exp) return ExpressionNode(OP_SQRT, exp)
def negation_to_node(node):
"""
Recursively replace negation flags inside a node by explicit unary negation
nodes.
"""
if node.negated:
negations = node.negated
node = negate(node, 0)
for i in range(negations):
node = ExpressionNode('-', node)
if node.is_leaf:
return node
return ExpressionNode(node.op, *map(negation_to_node, node))
This diff is collapsed.
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from node import TYPE_OPERATOR from node import TYPE_OPERATOR, OP_MUL, Scope
import re import re
...@@ -80,6 +80,20 @@ def find_parent_node(root, child): ...@@ -80,6 +80,20 @@ def find_parent_node(root, child):
node = node[0] node = node[0]
def flatten_mult(node):
if node.is_leaf:
return node
if node.is_op(OP_MUL):
scope = Scope(node)
scope.nodes = map(flatten_mult, scope)
return scope.as_nary_node()
node.nodes = map(flatten_mult, node)
return node
def apply_suggestion(root, suggestion): def apply_suggestion(root, suggestion):
# TODO: clone the root node before modifying. After deep copying the root # TODO: clone the root node before modifying. After deep copying the root
# node, the subtree_map cannot be used since the hash() of each node in the # node, the subtree_map cannot be used since the hash() of each node in the
...@@ -100,6 +114,7 @@ def apply_suggestion(root, suggestion): ...@@ -100,6 +114,7 @@ def apply_suggestion(root, suggestion):
if parent_node: if parent_node:
parent_node.substitute(suggestion.root, subtree) parent_node.substitute(suggestion.root, subtree)
return root else:
root = subtree
return subtree return flatten_mult(root)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_NEG, OP_SIN, OP_COS, \ from ..node import OP_ADD, OP_MUL, OP_DIV, OP_POW, OP_NEG, OP_SIN, OP_COS, \
OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS, OP_SQRT, \ OP_TAN, OP_DER, OP_LOG, OP_INT, OP_INT_INDEF, OP_EQ, OP_ABS, OP_SQRT, \
OP_AND, OP_OR OP_AND, OP_OR, OP_DXDER, OP_PRIME
from .groups import match_combine_groups from .groups import match_combine_groups
from .factors import match_expand from .factors import match_expand
from .powers import match_add_exponents, match_subtract_exponents, \ from .powers import match_add_exponents, match_subtract_exponents, \
...@@ -85,3 +85,4 @@ RULES = { ...@@ -85,3 +85,4 @@ RULES = {
OP_AND: [match_multiple_equations, match_double_case], OP_AND: [match_multiple_equations, match_double_case],
OP_OR: [match_double_case], OP_OR: [match_double_case],
} }
RULES[OP_DXDER] = RULES[OP_PRIME] = RULES[OP_DER]
...@@ -188,9 +188,7 @@ def power_rule(root, args): ...@@ -188,9 +188,7 @@ def power_rule(root, args):
""" """
[f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]' [f(x) ^ g(x)]' -> [e ^ ln(f(x) ^ g(x))]'
""" """
x = second_arg(root) return der(L(E) ** ln(root[0]), second_arg(root))
return der(L(E) ** ln(root[0]), x)
MESSAGES[power_rule] = \ MESSAGES[power_rule] = \
......
...@@ -60,9 +60,9 @@ def match_expand(node): ...@@ -60,9 +60,9 @@ def match_expand(node):
def expand(root, args): def expand(root, args):
""" """
(a + b)(c + d) -> ac + ad + bc + bd
(a + b)c -> ac + bc
a(b + c) -> ab + ac a(b + c) -> ab + ac
(a + b)c -> ac + bc
(a + b)(c + d) -> ac + ad + bc + bd
etc.. etc..
""" """
scope, left, right = args scope, left, right = args
......
...@@ -226,7 +226,7 @@ def raised_base(root, args): ...@@ -226,7 +226,7 @@ def raised_base(root, args):
return args[0] return args[0]
MESSAGES[raised_base] = _('Apply `g ^ (log_(g)(a)) = a` on {0}.') MESSAGES[raised_base] = _('Apply `g ^ log_g(a) = a` to {0}.')
def match_factor_out_exponent(node): def match_factor_out_exponent(node):
......
...@@ -72,12 +72,7 @@ def add_numerics(root, args): ...@@ -72,12 +72,7 @@ def add_numerics(root, args):
-2 + -3 -> -5 -2 + -3 -> -5
""" """
scope, c0, c1 = args scope, c0, c1 = args
value = c0.actual_value() + c1.actual_value() scope.replace(c0, Leaf(c0.actual_value() + c1.actual_value()))
# Replace the left node with the new expression
scope.replace(c0, Leaf(abs(value), negated=int(value < 0)))
# Remove the right node
scope.remove(c1) scope.remove(c1)
return scope.as_nary_node() return scope.as_nary_node()
...@@ -193,13 +188,10 @@ def match_multiply_numerics(node): ...@@ -193,13 +188,10 @@ def match_multiply_numerics(node):
numerics = filter(is_numeric_node, scope) numerics = filter(is_numeric_node, scope)
for n in numerics: for n in numerics:
if n.negated:
continue
if n.value == 0: if n.value == 0:
p.append(P(node, multiply_zero, (n,))) p.append(P(node, multiply_zero, (n,)))
if n.value == 1: if not n.negated and n.value == 1:
p.append(P(node, multiply_one, (scope, n))) p.append(P(node, multiply_one, (scope, n)))
for c0, c1 in combinations(numerics, 2): for c0, c1 in combinations(numerics, 2):
......
...@@ -22,7 +22,7 @@ from .derivatives import chain_rule ...@@ -22,7 +22,7 @@ from .derivatives import chain_rule
from .negation import double_negation, negated_factor, negated_nominator, \ from .negation import double_negation, negated_factor, negated_nominator, \
negated_denominator, negated_zero negated_denominator, negated_zero
from .fractions import multiply_with_fraction, divide_fraction_by_term, \ from .fractions import multiply_with_fraction, divide_fraction_by_term, \
add_nominators add_nominators, division_by_one
from .integrals import factor_out_constant, integrate_variable_root from .integrals import factor_out_constant, integrate_variable_root
from .powers import remove_power_of_one from .powers import remove_power_of_one
from .sqrt import quadrant_sqrt, extract_sqrt_mult_priority from .sqrt import quadrant_sqrt, extract_sqrt_mult_priority
...@@ -37,6 +37,16 @@ HIGH = [ ...@@ -37,6 +37,16 @@ HIGH = [
# 4 / 4 + 1 / 4 -> 5 / 4 instead of 1 + 1/4 # 4 / 4 + 1 / 4 -> 5 / 4 instead of 1 + 1/4
add_nominators, add_nominators,
# Some operations are obvious, they are mostly done on-the-fly
multiply_zero,
multiply_one,
remove_zero,
double_negation,
division_by_one,
add_numerics,
multiply_numerics,
negated_factor,
] ]
...@@ -104,12 +114,12 @@ IMPLICIT_RULES = [ ...@@ -104,12 +114,12 @@ IMPLICIT_RULES = [
double_negation, double_negation,
negated_nominator, negated_nominator,
negated_denominator, negated_denominator,
multiply_one,
multiply_zero, multiply_zero,
multiply_one,
division_by_one,
negated_zero, negated_zero,
remove_zero, remove_zero,
remove_power_of_one, remove_power_of_one,
negated_factor,
add_numerics, add_numerics,
swap_factors, swap_factors,
] ]
...@@ -16,11 +16,10 @@ import sys ...@@ -16,11 +16,10 @@ import sys
from external.graph_drawing.graph import generate_graph from external.graph_drawing.graph import generate_graph
from external.graph_drawing.line import generate_line from external.graph_drawing.line import generate_line
from src.node import negation_to_node
def create_graph(node): def create_graph(node):
return generate_graph(negation_to_node(node)) return node.graph() if node else None
class ParserWrapper(object): class ParserWrapper(object):
......
...@@ -42,7 +42,7 @@ class TestB1Ch10(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestB1Ch10(unittest.TestCase):
) )
), ),
('-x^3*-2x^5', ('-x^3*-2x^5',
-(L('x') ** L(3) * -(L(2) * L('x') ** L(5))) -(L('x') ** L(3) * -L(2) * L('x') ** L(5))
), ),
('(7x^2y^3)^2/(7x^2y^3)', ('(7x^2y^3)^2/(7x^2y^3)',
N('/', N('/',
......
...@@ -19,45 +19,45 @@ class TestLeidenOefenopgave(TestCase): ...@@ -19,45 +19,45 @@ class TestLeidenOefenopgave(TestCase):
def test_1_1(self): def test_1_1(self):
self.assertRewrite([ self.assertRewrite([
'-5(x ^ 2 - 3x + 6)', '-5(x ^ 2 - 3x + 6)',
'-(5x ^ 2 + 5(-3x) + 5 * 6)', '-(5x ^ 2 + 5 * -3x + 5 * 6)',
'-(5x ^ 2 - 5 * 3x + 5 * 6)', '-(5x ^ 2 + (-15)x + 5 * 6)',
'-(5x ^ 2 - 15x + 5 * 6)', '-(5x ^ 2 + (-15)x + 30)',
'-(5x ^ 2 - 15x + 30)', '-(5x ^ 2 - 15x + 30)',
'-5x ^ 2 - -15x - 30', '-5x ^ 2 - -15x - 30',
'-5x ^ 2 + 15x - 30', '-5x ^ 2 + 15x - 30',
]) ])
return
#for exp, solution in [
for exp, solution in [ # ('-5(x^2 - 3x + 6)', '-30 + 15x - 5x ^ 2'),
('-5(x^2 - 3x + 6)', '-30 + 15x - 5x ^ 2'), # ('(x+1)^2', 'x ^ 2 + 2x + 1'),
('(x+1)^2', 'x ^ 2 + 2x + 1'), # ('(x-1)^2', 'x ^ 2 - 2x + 1'),
('(x-1)^2', 'x ^ 2 - 2x + 1'), # ('(2x+x)*x', '3x ^ 2'),
('(2x+x)*x', '3x ^ 2'), # ('-2(6x-4)^2*x', '-72x ^ 3 + 96x ^ 2 + 32x'),
('-2(6x-4)^2*x', '-72x ^ 3 + 96x ^ 2 + 32x'), # ('(4x + 5) * -(5 - 4x)', '16x^2 - 25'),
('(4x + 5) * -(5 - 4x)', '16x^2 - 25'), # ]:
]: # self.assertEqual(str(rewrite(exp)), solution)
self.assertEqual(str(rewrite(exp)), solution)
def test_1_2(self): def test_1_2(self):
self.assertRewrite([ self.assertRewrite([
'(x+1)^3', '(x + 1)(x + 1) ^ 2', '(x + 1) ^ 3',
'(x + 1)(x + 1) ^ 2',
'(x + 1)(x + 1)(x + 1)', '(x + 1)(x + 1)(x + 1)',
'(xx + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x * 1 + 1x + 1 * 1)(x + 1)',
'(x ^ (1 + 1) + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x + 1x + 1 * 1)(x + 1)',
'(x ^ 2 + x * 1 + 1x + 1 * 1)(x + 1)', '(xx + x + x + 1 * 1)(x + 1)',
'(x ^ 2 + x + 1x + 1 * 1)(x + 1)', '(xx + x + x + 1)(x + 1)',
'(x ^ 2 + x + x + 1 * 1)(x + 1)', '(x ^ (1 + 1) + x + x + 1)(x + 1)',
'(x ^ 2 + x + x + 1)(x + 1)', '(x ^ 2 + x + x + 1)(x + 1)',
'(x ^ 2 + (1 + 1)x + 1)(x + 1)', '(x ^ 2 + (1 + 1)x + 1)(x + 1)',
'(x ^ 2 + 2x + 1)(x + 1)', '(x ^ 2 + 2x + 1)(x + 1)',
'x ^ 2 * x + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1',
'x ^ (2 + 1) + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x * 1 + 1x + 1 * 1',
'x ^ 3 + x ^ 2 * 1 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + 1x + 1 * 1',
'x ^ 3 + x ^ 2 + 2xx + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + x + 1 * 1',
'x ^ 3 + x ^ 2 + 2x ^ (1 + 1) + 2x * 1 + 1x + 1 * 1', 'x ^ 2 * x + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x * 1 + 1x + 1 * 1', 'x ^ (2 + 1) + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + 1x + 1 * 1', 'x ^ 3 + x ^ 2 + 2xx + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1 * 1', 'x ^ 3 + x ^ 2 + 2x ^ (1 + 1) + 2x + x + 1',
'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1', 'x ^ 3 + x ^ 2 + 2x ^ 2 + 2x + x + 1',
'x ^ 3 + (1 + 2)x ^ 2 + 2x + x + 1', 'x ^ 3 + (1 + 2)x ^ 2 + 2x + x + 1',
'x ^ 3 + 3x ^ 2 + 2x + x + 1', 'x ^ 3 + 3x ^ 2 + 2x + x + 1',
...@@ -68,12 +68,13 @@ class TestLeidenOefenopgave(TestCase): ...@@ -68,12 +68,13 @@ class TestLeidenOefenopgave(TestCase):
def test_1_3(self): def test_1_3(self):
# (x+1)^2 -> x^2 + 2x + 1 # (x+1)^2 -> x^2 + 2x + 1
self.assertRewrite([ self.assertRewrite([
'(x+1)^2', '(x + 1)(x + 1)', '(x + 1) ^ 2',
'(x + 1)(x + 1)',
'xx + x * 1 + 1x + 1 * 1', 'xx + x * 1 + 1x + 1 * 1',
'x ^ (1 + 1) + x * 1 + 1x + 1 * 1', 'xx + x + 1x + 1 * 1',
'x ^ 2 + x * 1 + 1x + 1 * 1', 'xx + x + x + 1 * 1',
'x ^ 2 + x + 1x + 1 * 1', 'xx + x + x + 1',
'x ^ 2 + x + x + 1 * 1', 'x ^ (1 + 1) + x + x + 1',
'x ^ 2 + x + x + 1', 'x ^ 2 + x + x + 1',
'x ^ 2 + (1 + 1)x + 1', 'x ^ 2 + (1 + 1)x + 1',
'x ^ 2 + 2x + 1', 'x ^ 2 + 2x + 1',
...@@ -84,25 +85,25 @@ class TestLeidenOefenopgave(TestCase): ...@@ -84,25 +85,25 @@ class TestLeidenOefenopgave(TestCase):
self.assertRewrite([ self.assertRewrite([
'(x - 1) ^ 2', '(x - 1) ^ 2',
'(x - 1)(x - 1)', '(x - 1)(x - 1)',
'xx + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x + (-1) * -1',
'x ^ (1 + 1) + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x - -1',
'x ^ 2 + x(-1) + (-1)x + (-1)(-1)', 'xx + x * -1 + (-1)x + 1',
'x ^ 2 - x * 1 + (-1)x + (-1)(-1)', 'xx - x * 1 + (-1)x + 1',
'x ^ 2 - x + (-1)x + (-1)(-1)', 'xx - x + (-1)x + 1',
'x ^ 2 - x - 1x + (-1)(-1)', 'xx - x - 1x + 1',
'x ^ 2 - x - x + (-1)(-1)', 'xx - x - x + 1',
'x ^ 2 - x - x - -1', 'x ^ (1 + 1) - x - x + 1',
'x ^ 2 - x - x + 1', 'x ^ 2 - x - x + 1',
'x ^ 2 + (1 + 1)(-x) + 1', 'x ^ 2 + (1 + 1) * -x + 1',
'x ^ 2 + 2(-x) + 1', 'x ^ 2 + 2 * -x + 1',
'x ^ 2 - 2x + 1', 'x ^ 2 - 2x + 1',
]) ])
def test_1_4_1(self): def test_1_4_1(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 + 1x', 'x * -1 + 1x',
'-x * 1 + 1x', 'x * -1 + x',
'-x + 1x', '-x * 1 + x',
'-x + x', '-x + x',
'(-1 + 1)x', '(-1 + 1)x',
'0x', '0x',
...@@ -112,23 +113,23 @@ class TestLeidenOefenopgave(TestCase): ...@@ -112,23 +113,23 @@ class TestLeidenOefenopgave(TestCase):
def test_1_4_2(self): def test_1_4_2(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 - 1x', 'x * -1 - 1x',
'-x * 1 - 1x', 'x * -1 - x',
'-x - 1x', '-x * 1 - x',
'-x - x', '-x - x',
'(1 + 1)(-x)', '(1 + 1) * -x',
'2(-x)', '2 * -x',
'-2x', '-2x',
]) ])
def test_1_4_3(self): def test_1_4_3(self):
self.assertRewrite([ self.assertRewrite([
'x * -1 + x * -1', 'x * -1 + x * -1',
'-x * 1 + x(-1)', '-x * 1 + x * -1',
'-x + x(-1)', '-x + x * -1',
'-x - x * 1', '-x - x * 1',
'-x - x', '-x - x',
'(1 + 1)(-x)', '(1 + 1) * -x',
'2(-x)', '2 * -x',
'-2x', '-2x',
]) ])
...@@ -144,19 +145,21 @@ class TestLeidenOefenopgave(TestCase): ...@@ -144,19 +145,21 @@ class TestLeidenOefenopgave(TestCase):
def test_1_7(self): def test_1_7(self):
self.assertRewrite([ self.assertRewrite([
'(4x + 5) * -(5 - 4x)', '(4x + 5) * -(5 - 4x)',
'(4x + 5)(-5 - -4x)', '-(4x + 5)(5 - 4x)',
'(4x + 5)(-5 + 4x)', '-(4x * 5 + 4x * -4x + 5 * 5 + 5 * -4x)',
'4x(-5) + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + 4x * -4x + 5 * 5 + 5 * -4x)',
'(-20)x + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 5 * 5 + 5 * -4x)',
'-20x + 4x * 4x + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 25 + 5 * -4x)',
'-20x + 16xx + 5(-5) + 5 * 4x', '-(20x + (-16)xx + 25 + (-20)x)',
'-20x + 16x ^ (1 + 1) + 5(-5) + 5 * 4x', '-(20x - 16xx + 25 + (-20)x)',
'-20x + 16x ^ 2 + 5(-5) + 5 * 4x', '-(20x - 16xx + 25 - 20x)',
'-20x + 16x ^ 2 - 25 + 5 * 4x', '-(20x - 16x ^ (1 + 1) + 25 - 20x)',
'-20x + 16x ^ 2 - 25 + 20x', '-(20x - 16x ^ 2 + 25 - 20x)',
'(-1 + 1)20x + 16x ^ 2 - 25', '-((1 - 1)20x - 16x ^ 2 + 25)',
'0 * 20x + 16x ^ 2 - 25', '-(0 * 20x - 16x ^ 2 + 25)',
'0 + 16x ^ 2 - 25', '-(0 - 16x ^ 2 + 25)',
'-(-16x ^ 2 + 25)',
'--16x ^ 2 - 25',
'16x ^ 2 - 25', '16x ^ 2 - 25',
]) ])
...@@ -176,7 +179,7 @@ class TestLeidenOefenopgave(TestCase): ...@@ -176,7 +179,7 @@ class TestLeidenOefenopgave(TestCase):
def test_4_2(self): def test_4_2(self):
self.assertRewrite([ self.assertRewrite([
'2/7 - 4/11', '2 / 7 - 4 / 11',
'22 / 77 - 28 / 77', '22 / 77 - 28 / 77',
'(22 - 28) / 77', '(22 - 28) / 77',
'(-6) / 77', '(-6) / 77',
......
...@@ -16,17 +16,6 @@ from tests.rulestestcase import RulesTestCase as TestCase ...@@ -16,17 +16,6 @@ from tests.rulestestcase import RulesTestCase as TestCase
class TestLeidenOefenopgaveV12(TestCase): class TestLeidenOefenopgaveV12(TestCase):
def test_1_a(self):
self.assertRewrite([
'-5(x^2 - 3x + 6)',
'-(5x ^ 2 + 5(-3x) + 5 * 6)',
'-(5x ^ 2 - 5 * 3x + 5 * 6)',
'-(5x ^ 2 - 15x + 5 * 6)',
'-(5x ^ 2 - 15x + 30)',
'-5x ^ 2 - -15x - 30',
'-5x ^ 2 + 15x - 30',
])
def test_1_d(self): def test_1_d(self):
self.assertRewrite([ self.assertRewrite([
'(2x + x)x', '(2x + x)x',
...@@ -40,30 +29,28 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -40,30 +29,28 @@ class TestLeidenOefenopgaveV12(TestCase):
self.assertRewrite([ self.assertRewrite([
'-2(6x - 4) ^ 2x', '-2(6x - 4) ^ 2x',
'-2(6x - 4)(6x - 4)x', '-2(6x - 4)(6x - 4)x',
'-(2 * 6x + 2(-4))(6x - 4)x', '-(2 * 6x + 2 * -4)(6x - 4)x',
'-(12x + 2(-4))(6x - 4)x', '-(12x + 2 * -4)(6x - 4)x',
'-(12x - 8)(6x - 4)x', '-(12x - 8)(6x - 4)x',
'-(12x - 8)(6xx + (-4)x)', '-(12x - 8)(6xx + (-4)x)',
'-(12x - 8)(6x ^ (1 + 1) + (-4)x)', '-(12x - 8)(6xx - 4x)',
'-(12x - 8)(6x ^ 2 + (-4)x)', '-(12x - 8)(6x ^ (1 + 1) - 4x)',
'-(12x - 8)(6x ^ 2 - 4x)', '-(12x - 8)(6x ^ 2 - 4x)',
'-(12x * 6x ^ 2 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(12x * 6x ^ 2 + 12x * -4x + (-8)6x ^ 2 + (-8) * -4x)',
'-(72xx ^ 2 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + 12x * -4x + (-8)6x ^ 2 + (-8) * -4x)',
'-(72x ^ (1 + 2) + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-8)6x ^ 2 + (-8) * -4x)',
'-(72x ^ 3 + 12x(-4x) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + (-8) * -4x)',
'-(72x ^ 3 - 12x * 4x + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + (--32)x)',
'-(72x ^ 3 - 48xx + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 + (-48)xx + (-48)x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ (1 + 1) + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 - 48xx + (-48)x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 + (-8)6x ^ 2 + (-8)(-4x))', '-(72x x ^ 2 - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 + (-48)x ^ 2 + (-8)(-4x))', '-(72x ^ (1 + 2) - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + (-8)(-4x))', '-(72x ^ 3 - 48xx - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - 8(-4x))', '-(72x ^ 3 - 48x ^ (1 + 1) - 48x ^ 2 + 32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - -8 * 4x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 - -32x)',
'-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + 32x)', '-(72x ^ 3 - 48x ^ 2 - 48x ^ 2 + 32x)',
'-(72x ^ 3 + (1 + 1)(-48x ^ 2) + 32x)', '-(72x ^ 3 + (1 + 1) * -48x ^ 2 + 32x)',
'-(72x ^ 3 + 2(-48x ^ 2) + 32x)', '-(72x ^ 3 + 2 * -48x ^ 2 + 32x)',
'-(72x ^ 3 - 2 * 48x ^ 2 + 32x)', '-(72x ^ 3 + (-96)x ^ 2 + 32x)',
'-(72x ^ 3 - 96x ^ 2 + 32x)', '-(72x ^ 3 - 96x ^ 2 + 32x)',
'-72x ^ 3 - -96x ^ 2 - 32x', '-72x ^ 3 - -96x ^ 2 - 32x',
'-72x ^ 3 + 96x ^ 2 - 32x', '-72x ^ 3 + 96x ^ 2 - 32x',
...@@ -71,23 +58,23 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -71,23 +58,23 @@ class TestLeidenOefenopgaveV12(TestCase):
def test_2_a(self): def test_2_a(self):
self.assertRewrite([ self.assertRewrite([
'(a ^ 2 * b ^ -1) ^ 3(ab ^ 2)', '(a ^ 2 * b ^ -1) ^ 3 * a b ^ 2',
'(a ^ 2 * 1 / b ^ 1) ^ 3 * ab ^ 2', '(a ^ 2 * 1 / b ^ 1) ^ 3 * a b ^ 2',
'(a ^ 2 * 1 / b) ^ 3 * ab ^ 2', '(a ^ 2 * 1 / b) ^ 3 * a b ^ 2',
'((a ^ 2 * 1) / b) ^ 3 * ab ^ 2', '((a ^ 2 * 1) / b) ^ 3 * a b ^ 2',
'(a ^ 2 / b) ^ 3 * ab ^ 2', '(a ^ 2 / b) ^ 3 * a b ^ 2',
'(a ^ 2) ^ 3 / b ^ 3 * ab ^ 2', '(a ^ 2) ^ 3 / b ^ 3 * a b ^ 2',
'a ^ (2 * 3) / b ^ 3 * ab ^ 2', 'a ^ (2 * 3) / b ^ 3 * a b ^ 2',
'a ^ 6 / b ^ 3 * ab ^ 2', 'a ^ 6 / b ^ 3 * a b ^ 2',
'(a ^ 6 * a) / b ^ 3 * b ^ 2', '(a ^ 6 * a) / b ^ 3 * b ^ 2',
'a ^ (6 + 1) / b ^ 3 * b ^ 2', 'a ^ (6 + 1) / b ^ 3 * b ^ 2',
'a ^ 7 / b ^ 3 * b ^ 2', 'a ^ 7 / b ^ 3 * b ^ 2',
'(a ^ 7 * b ^ 2) / b ^ 3', '(a ^ 7 * b ^ 2) / b ^ 3',
'b ^ 2 / b ^ 3 * a ^ 7 / 1', 'b ^ 2 / b ^ 3 * a ^ 7 / 1',
'b ^ (2 - 3)a ^ 7 / 1', 'b ^ 2 / b ^ 3 * a ^ 7',
'b ^ (-1)a ^ 7 / 1', 'b ^ (2 - 3)a ^ 7',
'1 / b ^ 1 * a ^ 7 / 1', 'b ^ -1 * a ^ 7',
'1 / b * a ^ 7 / 1', '1 / b ^ 1 * a ^ 7',
'1 / b * a ^ 7', '1 / b * a ^ 7',
'(1a ^ 7) / b', '(1a ^ 7) / b',
'a ^ 7 / b', 'a ^ 7 / b',
...@@ -124,9 +111,9 @@ class TestLeidenOefenopgaveV12(TestCase): ...@@ -124,9 +111,9 @@ class TestLeidenOefenopgaveV12(TestCase):
def test_2_f(self): def test_2_f(self):
self.assertRewrite([ self.assertRewrite([
'(4b) ^ -2', '(4b) ^ -2',
'4 ^ (-2)b ^ (-2)', '4 ^ -2 * b ^ -2',
'1 / 4 ^ 2 * b ^ (-2)', '1 / 4 ^ 2 * b ^ -2',
'1 / 16 * b ^ (-2)', '1 / 16 * b ^ -2',
'1 / 16 * 1 / b ^ 2', '1 / 16 * 1 / b ^ 2',
'(1 * 1) / (16b ^ 2)', '(1 * 1) / (16b ^ 2)',
'1 / (16b ^ 2)', '1 / (16b ^ 2)',
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \ from src.node import ExpressionNode as N, ExpressionLeaf as L, Scope, \
nary_node, get_scope, OP_ADD, infinity, absolute, sin, cos, tan, log, \ nary_node, get_scope, OP_ADD, infinity, absolute, sin, cos, tan, log, \
ln, der, integral, indef, eq, negation_to_node ln, der, integral, indef, eq
from tests.rulestestcase import RulesTestCase, tree from tests.rulestestcase import RulesTestCase, tree
...@@ -121,8 +121,8 @@ class TestNode(RulesTestCase): ...@@ -121,8 +121,8 @@ class TestNode(RulesTestCase):
self.assertEqual(get_scope(plus), self.l) self.assertEqual(get_scope(plus), self.l)
def test_get_scope_negation(self): def test_get_scope_negation(self):
root, a, b, cd = tree('a * b * -cd, a, b, -cd') root, a, b, c, d = tree('ab * -cd, a, b, -c, d')
self.assertEqual(get_scope(root), [a, b, cd]) self.assertEqual(get_scope(root), [a, b, c, d])
def test_get_scope_index(self): def test_get_scope_index(self):
self.assertEqual(self.scope.index(self.a), 0) self.assertEqual(self.scope.index(self.a), 0)
...@@ -244,15 +244,15 @@ class TestNode(RulesTestCase): ...@@ -244,15 +244,15 @@ class TestNode(RulesTestCase):
self.assertTrue(ma.contains(a)) self.assertTrue(ma.contains(a))
def test_construct_function_derivative(self): def test_construct_function_derivative(self):
self.assertEqual(str(tree('der(x ^ 2)')), '[x ^ 2]\'') self.assertEqual(str(tree("(x ^ 2)'")), "[x ^ 2]'")
self.assertEqual(str(tree('der(der(x ^ 2))')), '[x ^ 2]\'\'') self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''")
self.assertEqual(str(tree('der(x ^ 2, x)')), 'd/dx (x ^ 2)') self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2')
def test_construct_function_logarithm(self): def test_construct_function_logarithm(self):
self.assertEqual(str(tree('log(x, e)')), 'ln(x)') self.assertEqual(str(tree('log(x, e)')), 'ln x')
self.assertEqual(str(tree('log(x, 10)')), 'log(x)') self.assertEqual(str(tree('log(x, 10)')), 'log x')
self.assertEqual(str(tree('log(x, 2)')), 'log_2(x)') self.assertEqual(str(tree('log(x, 2)')), 'log_2 x')
self.assertEqual(str(tree('log(x, g)')), 'log(x, g)') self.assertEqual(str(tree('log(x, g)')), 'log_g x')
def test_construct_function_integral(self): def test_construct_function_integral(self):
self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx') self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx')
...@@ -318,9 +318,3 @@ class TestNode(RulesTestCase): ...@@ -318,9 +318,3 @@ class TestNode(RulesTestCase):
def test_eq(self): def test_eq(self):
x, a, b, expect = tree('x, a, b, x + a = b') x, a, b, expect = tree('x, a, b, x + a = b')
self.assertEqual(eq(x + a, b), expect) self.assertEqual(eq(x + a, b), expect)
def test_negation_to_node(self):
a = tree('a')
self.assertEqual(negation_to_node(-a), N('-', a))
self.assertEqual(negation_to_node(-(a + 1)), N('-', a + 1))
self.assertEqual(negation_to_node(-(a - 1)), N('-', a + N('-', L(1))))
...@@ -89,12 +89,13 @@ class TestParser(RulesTestCase): ...@@ -89,12 +89,13 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('sin x'), sin(x)) self.assertEqual(tree('sin x'), sin(x))
self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct? self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct?
self.assertEqual(tree('sin x ^ 2'), sin(x ** 2)) self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2) self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2)) self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
self.assertEqual(tree('sin cos x'), sin(cos(x))) self.assertEqual(tree('sin cos x'), sin(cos(x)))
self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2)) self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
def test_brackets(self): def test_brackets(self):
self.assertEqual(*tree('[x], x')) self.assertEqual(*tree('[x], x'))
...@@ -145,7 +146,7 @@ class TestParser(RulesTestCase): ...@@ -145,7 +146,7 @@ class TestParser(RulesTestCase):
# FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a) # FIXME: self.assertEqual(tree('a' + token + 'a'), a * t * a)
def test_integral(self): def test_integral(self):
x, y, dx, a, b, l2 = tree('x, y, dx, a, b, 2') x, y, dx, a, b, l2, oo = tree('x, y, dx, a, b, 2, oo')
self.assertEqual(tree('int x'), integral(x, x)) self.assertEqual(tree('int x'), integral(x, x))
self.assertEqual(tree('int x ^ 2'), integral(x ** 2, x)) self.assertEqual(tree('int x ^ 2'), integral(x ** 2, x))
...@@ -162,21 +163,18 @@ class TestParser(RulesTestCase): ...@@ -162,21 +163,18 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('int_a^(b2) x'), integral(x, x, a, b * 2)) self.assertEqual(tree('int_a^(b2) x'), integral(x, x, a, b * 2))
self.assertEqual(tree('int x ^ 2 + 1'), integral(x ** 2, x) + 1) self.assertEqual(tree('int x ^ 2 + 1'), integral(x ** 2, x) + 1)
self.assertEqual(tree('int x ^ 2 + 1 dx'), integral(x ** 2 + 1, x))
self.assertEqual(tree('int_a^b x ^ 2 dx'), integral(x ** 2, x, a, b)) self.assertEqual(tree('int_a^b x ^ 2 dx'), integral(x ** 2, x, a, b))
self.assertEqual(tree('int_a^(b2) x ^ 2 + 1 dx'), self.assertEqual(tree('int_a x ^ 2 dx'), integral(x ** 2, x, a, oo))
integral(x ** 2 + 1, x, a, b * 2))
self.assertEqual(tree('int_(a^2)^b x ^ 2 + 1 dx'),
integral(x ** 2 + 1, x, a ** 2, b))
self.assertEqual(tree('int_(-a)^b x dx'), integral(x, x, -a, b)) self.assertEqual(tree('int_(-a)^b x dx'), integral(x, x, -a, b))
# FIXME: self.assertEqual(tree('int_-a^b x dx'), integral(x, x, -a, b)) #self.assertEqual(tree('int_-a^b x dx'), integral(x, x, -a, b))
def test_indefinite_integral(self): def test_indefinite_integral(self):
x2, a, b = tree('x ^ 2, a, b') x2, a, b, oo = tree('x ^ 2, a, b, oo')
self.assertEqual(tree('[x ^ 2]_a^b'), indef(x2, a, b)) self.assertEqual(tree('(x ^ 2)_a'), indef(x2, a, oo))
self.assertEqual(tree('(x ^ 2)_a^b'), indef(x2, a, b))
def test_absolute_value(self): def test_absolute_value(self):
x = tree('x') x = tree('x')
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# along with TRS. If not, see <http://www.gnu.org/licenses/>. # along with TRS. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from src.possibilities import MESSAGES, Possibility as P from src.possibilities import MESSAGES, Possibility as P, flatten_mult
from tests.rulestestcase import tree from tests.rulestestcase import tree
from src.parser import Parser from src.parser import Parser
...@@ -80,23 +80,6 @@ class TestPossibilities(unittest.TestCase): ...@@ -80,23 +80,6 @@ class TestPossibilities(unittest.TestCase):
'<Possibility root="3 + 4" handler=add_numerics' \ '<Possibility root="3 + 4" handler=add_numerics' \
' args=(<Scope of "3 + 4">, 3, 4)>') ' args=(<Scope of "3 + 4">, 3, 4)>')
#def test_filter_duplicates(self): def test_flatten_mult(self):
# a, b = ab = tree('a + b') self.assertEqual(flatten_mult(tree('2(xx)')), tree('2xx'))
# p0 = P(a, dummy_handler, (1, 2)) self.assertEqual(flatten_mult(tree('2(xx) + 1')), tree('2xx + 1'))
# p1 = P(ab, dummy_handler, (1, 2))
# p2 = P(ab, dummy_handler, (1, 2, 3))
# p3 = P(ab, dummy_handler_msg, (1, 2))
# self.assertEqual(filter_duplicates([]), [])
# self.assertEqual(filter_duplicates([p0, p1]), [p1])
# self.assertEqual(filter_duplicates([p1, p2]), [p1, p2])
# self.assertEqual(filter_duplicates([p1, p3]), [p1, p3])
# self.assertEqual(filter_duplicates([p0, p1, p2, p3]), [p1, p2, p3])
# # Docstrings example
# (l1, l2), l3 = left, l3 = right = tree('1 + 2 + 3')
# p0 = P(left, add_numerics, (1, 2, 1, 2))
# p1 = P(right, add_numerics, (1, 2, 1, 2))
# p2 = P(right, add_numerics, (1, 3, 1, 3))
# p3 = P(right, add_numerics, (2, 3, 2, 3))
# self.assertEqual(filter_duplicates([p0, p1, p2, p3]), [p1, p2, p3])
This diff is collapsed.
...@@ -320,17 +320,17 @@ class TestRulesFractions(RulesTestCase): ...@@ -320,17 +320,17 @@ class TestRulesFractions(RulesTestCase):
'1 / (1 / b - 1 / a)', '1 / (1 / b - 1 / a)',
'(b * 1) / (b(1 / b - 1 / a))', '(b * 1) / (b(1 / b - 1 / a))',
'b / (b(1 / b - 1 / a))', 'b / (b(1 / b - 1 / a))',
'b / (b * 1 / b + b(-1 / a))', 'b / (b * 1 / b + b * -1 / a)',
'b / ((b * 1) / b + b(-1 / a))', 'b / (b * 1 / b - b * 1 / a)',
'b / (b / b + b(-1 / a))', 'b / ((b * 1) / b - b * 1 / a)',
'b / (1 + b(-1 / a))', 'b / (b / b - b * 1 / a)',
'b / (1 - b * 1 / a)', 'b / (1 - b * 1 / a)',
'b / (1 - (b * 1) / a)', 'b / (1 - (b * 1) / a)',
'b / (1 - b / a)', 'b / (1 - b / a)',
'(ab) / (a(1 - b / a))', '(ab) / (a(1 - b / a))',
'(ab) / (a * 1 + a(-b / a))', '(ab) / (a * 1 + a * -b / a)',
'(ab) / (a + a(-b / a))', '(ab) / (a + a * -b / a)',
'(ab) / (a - ab / a)', '(ab) / (a - a b / a)',
'(ab) / (a - (ab) / a)', '(ab) / (a - (ab) / a)',
'(ab) / (a - b)', '(ab) / (a - b)',
]) ])
...@@ -30,30 +30,30 @@ class TestRulesGoniometry(RulesTestCase): ...@@ -30,30 +30,30 @@ class TestRulesGoniometry(RulesTestCase):
self.assertEqual(doctest.testmod(m=goniometry)[0], 0) self.assertEqual(doctest.testmod(m=goniometry)[0], 0)
def test_match_add_quadrants(self): def test_match_add_quadrants(self):
s, c = root = tree('sin(t) ^ 2 + cos(t) ^ 2') s, c = root = tree('sin^2 t + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
c, s = root = tree('cos(t) ^ 2 + sin(t) ^ 2') c, s = root = tree('cos^2 t + sin^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
(s, a), c = root = tree('sin(t) ^ 2 + a + cos(t) ^ 2') (s, a), c = root = tree('sin^2 t + a + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c))]) [P(root, add_quadrants, (Scope(root), s, c))])
(s, c0), c1 = root = tree('sin(t) ^ 2 + cos(t) ^ 2 + cos(t) ^ 2') (s, c0), c1 = root = tree('sin^2 t + cos^2 t + cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, add_quadrants, (Scope(root), s, c0)), [P(root, add_quadrants, (Scope(root), s, c0)),
P(root, add_quadrants, (Scope(root), s, c1))]) P(root, add_quadrants, (Scope(root), s, c1))])
root = tree('sin(t) ^ 2 + cos(y) ^ 2') root = tree('sin^2 t + cos^2 y')
self.assertEqualPos(match_add_quadrants(root), []) self.assertEqualPos(match_add_quadrants(root), [])
root = tree('sin(t) ^ 2 - cos(t) ^ 2') root = tree('sin^2 t - cos^2 t')
self.assertEqualPos(match_add_quadrants(root), []) self.assertEqualPos(match_add_quadrants(root), [])
s, c = root = tree('-sin(t) ^ 2 - cos(t) ^ 2') s, c = root = tree('-sin^2 t - cos^2 t')
self.assertEqualPos(match_add_quadrants(root), self.assertEqualPos(match_add_quadrants(root),
[P(root, factor_out_quadrant_negation, (Scope(root), s, c))]) [P(root, factor_out_quadrant_negation, (Scope(root), s, c))])
......
...@@ -141,9 +141,9 @@ class TestRulesIntegrals(RulesTestCase): ...@@ -141,9 +141,9 @@ class TestRulesIntegrals(RulesTestCase):
self.assertRewrite([ self.assertRewrite([
'int a / x', 'int a / x',
'int a * 1 / x dx', 'int a * 1 / x dx',
'aint 1 / x dx', 'a(int 1 / x dx)',
'a(ln|x| + C)', 'a(ln|x| + C)',
'aln|x| + aC', 'a ln|x| + aC',
# FIXME: 'aln|x| + C', # ac -> C # FIXME: 'aln|x| + C', # ac -> C
]) ])
...@@ -176,24 +176,24 @@ class TestRulesIntegrals(RulesTestCase): ...@@ -176,24 +176,24 @@ class TestRulesIntegrals(RulesTestCase):
self.assertEqual(cosinus_integral(root, ()), expect) self.assertEqual(cosinus_integral(root, ()), expect)
def test_match_sum_rule_integral(self): def test_match_sum_rule_integral(self):
(f, g), x = root = tree('int 2x + 3x dx') (f, g), x = root = tree('int (2x + 3x) dx')
self.assertEqualPos(match_sum_rule_integral(root), self.assertEqualPos(match_sum_rule_integral(root),
[P(root, sum_rule_integral, (Scope(root[0]), f))]) [P(root, sum_rule_integral, (Scope(root[0]), f))])
((f, g), h), x = root = tree('int 2x + 3x + 4x dx') ((f, g), h), x = root = tree('int (2x + 3x + 4x) dx')
self.assertEqualPos(match_sum_rule_integral(root), self.assertEqualPos(match_sum_rule_integral(root),
[P(root, sum_rule_integral, (Scope(root[0]), f)), [P(root, sum_rule_integral, (Scope(root[0]), f)),
P(root, sum_rule_integral, (Scope(root[0]), g)), P(root, sum_rule_integral, (Scope(root[0]), g)),
P(root, sum_rule_integral, (Scope(root[0]), h))]) P(root, sum_rule_integral, (Scope(root[0]), h))])
def test_sum_rule_integral(self): def test_sum_rule_integral(self):
((f, g), h), x = root = tree('int 2x + 3x + 4x dx') ((f, g), h), x = root = tree('int (2x + 3x + 4x) dx')
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), f)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), f)),
tree('int 2x dx + int 3x + 4x dx')) tree('int 2x dx + int (3x + 4x) dx'))
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), g)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), g)),
tree('int 3x dx + int 2x + 4x dx')) tree('int 3x dx + int (2x + 4x) dx'))
self.assertEqual(sum_rule_integral(root, (Scope(root[0]), h)), self.assertEqual(sum_rule_integral(root, (Scope(root[0]), h)),
tree('int 4x dx + int 2x + 3x dx')) tree('int 4x dx + int (2x + 3x) dx'))
def test_match_remove_indef_constant(self): def test_match_remove_indef_constant(self):
Fx, a, b = root = tree('[2x + C]_a^b') Fx, a, b = root = tree('[2x + C]_a^b')
......
...@@ -98,8 +98,8 @@ class TestRulesLineq(RulesTestCase): ...@@ -98,8 +98,8 @@ class TestRulesLineq(RulesTestCase):
'2x = -3x - 5', '2x = -3x - 5',
'2x - -3x = -3x - 5 - -3x', '2x - -3x = -3x - 5 - -3x',
'2x + 3x = -3x - 5 - -3x', '2x + 3x = -3x - 5 - -3x',
'(2 + 3)x = -3x - 5 - -3x', '2x + 3x = -3x - 5 + 3x',
'5x = -3x - 5 - -3x', '(2 + 3)x = -3x - 5 + 3x',
'5x = -3x - 5 + 3x', '5x = -3x - 5 + 3x',
'5x = (-1 + 1)3x - 5', '5x = (-1 + 1)3x - 5',
'5x = 0 * 3x - 5', '5x = 0 * 3x - 5',
...@@ -116,11 +116,11 @@ class TestRulesLineq(RulesTestCase): ...@@ -116,11 +116,11 @@ class TestRulesLineq(RulesTestCase):
def test_match_move_term_chain_advanced(self): def test_match_move_term_chain_advanced(self):
self.assertRewrite([ self.assertRewrite([
'-x = a', '-x = a',
'(-x)(-1) = a(-1)', '(-x) * -1 = a * -1',
'-x(-1) = a(-1)', '-x * -1 = a * -1',
'--x * 1 = a(-1)', '--x * 1 = a * -1',
'--x = a(-1)', '--x = a * -1',
'x = a(-1)', 'x = a * -1',
'x = -a * 1', 'x = -a * 1',
'x = -a', 'x = -a',
]) ])
......
...@@ -106,8 +106,8 @@ class TestRulesNegation(RulesTestCase): ...@@ -106,8 +106,8 @@ class TestRulesNegation(RulesTestCase):
def test_double_negated_division(self): def test_double_negated_division(self):
self.assertRewrite([ self.assertRewrite([
'(-a) / (-b)', '(-a) / -b',
'-a / (-b)', '-a / -b',
'--a / b', '--a / b',
'a / b', 'a / b',
]) ])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment