Commit e9dd48f8 authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented correct parentheses parsing for logarithms

parent bbdcbe9c
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import os.path import os.path
import sys import sys
import copy import copy
import re
sys.path.insert(0, os.path.realpath('external')) sys.path.insert(0, os.path.realpath('external'))
...@@ -324,22 +325,33 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -324,22 +325,33 @@ class ExpressionNode(Node, ExpressionBase):
if self.op in UNARY_FUNCTIONS: if self.op in UNARY_FUNCTIONS:
return 1 return 1
if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE): #if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE):
return 1 # return 1
# Functions always have parentheses
if self.op in TOKEN_MAP and TOKEN_MAP[self.op] == 'FUNCTION':
return 2
return len(self) return len(self)
def operator(self): def operator(self):
# Append an opening parenthesis manually, the closing parentheses is
# appended by postprocess_str
if self.op == OP_LOG: if self.op == OP_LOG:
base = self[1].value base = self[1].value
if base == DEFAULT_LOGARITHM_BASE: if base == DEFAULT_LOGARITHM_BASE:
return self.value return self.value + '('
if base == E: if base == E:
return 'ln' return 'ln('
base = str(self[1])
return '%s_%s' % (self.value, str(self[1])) if not re.match('^[0-9]+|[a-zA-Z]$', base):
base = '(' + base + ')'
return '%s_%s(' % (self.value, base)
if self.op == OP_DXDER: if self.op == OP_DXDER:
return self.value + str(self[1]) return self.value + str(self[1])
...@@ -365,6 +377,10 @@ class ExpressionNode(Node, ExpressionBase): ...@@ -365,6 +377,10 @@ class ExpressionNode(Node, ExpressionBase):
self[0] = ExpressionNode(OP_BRACKETS, self[0]) self[0] = ExpressionNode(OP_BRACKETS, self[0])
def postprocess_str(self, s): def postprocess_str(self, s):
# A bit hacky, but forced because of operator() method
if self.op == OP_LOG:
return s.replace('( ', '(') + ')'
if self.op == OP_INT: if self.op == OP_INT:
return '%s d%s' % (s, self[1]) return '%s d%s' % (s, self[1])
......
...@@ -248,12 +248,6 @@ class TestNode(RulesTestCase): ...@@ -248,12 +248,6 @@ class TestNode(RulesTestCase):
self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''") self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''")
self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2') self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2')
def test_construct_function_logarithm(self):
self.assertEqual(str(tree('log(x, e)')), 'ln x')
self.assertEqual(str(tree('log(x, 10)')), 'log x')
self.assertEqual(str(tree('log(x, 2)')), 'log_2 x')
self.assertEqual(str(tree('log(x, g)')), 'log_g x')
def test_construct_function_integral(self): def test_construct_function_integral(self):
self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx') self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx')
self.assertEqual(str(tree('int x ^ 2 dx')), 'int x ^ 2 dx') self.assertEqual(str(tree('int x ^ 2 dx')), 'int x ^ 2 dx')
...@@ -271,8 +265,21 @@ class TestNode(RulesTestCase): ...@@ -271,8 +265,21 @@ class TestNode(RulesTestCase):
'[x ^ 2]_(a - b)^(a + b)') '[x ^ 2]_(a - b)^(a + b)')
def test_construct_function_absolute_child(self): def test_construct_function_absolute_child(self):
self.assertEqual(str(tree('ln(|x|)')), 'ln|x|') self.assertEqual(str(tree('ln(|x|)')), 'ln(|x|)')
self.assertEqual(str(tree('sin(|x|)')), 'sin|x|') self.assertEqual(str(tree('sin(|x|)')), 'sin(|x|)')
def test_construct_logarithm(self):
self.assertEqual(str(tree('log n')), 'log(n)')
self.assertEqual(str(tree('log(n)')), 'log(n)')
self.assertEqual(str(tree('ln n')), 'ln(n)')
self.assertEqual(str(tree('ln(n)')), 'ln(n)')
self.assertEqual(str(tree('log_2 n')), 'log_2(n)')
self.assertEqual(str(tree('log_2(n)')), 'log_2(n)')
self.assertEqual(str(tree('log_g n')), 'log_g(n)')
self.assertEqual(str(tree('log_(g + h) n')), 'log_(g + h)(n)')
def test_infinity(self): def test_infinity(self):
self.assertEqual(infinity(), tree('oo')) self.assertEqual(infinity(), tree('oo'))
......
...@@ -118,19 +118,19 @@ class TestRulesDerivatives(RulesTestCase): ...@@ -118,19 +118,19 @@ class TestRulesDerivatives(RulesTestCase):
def test_power_rule_chain(self): def test_power_rule_chain(self):
self.assertRewrite([ self.assertRewrite([
"[x ^ x]'", "[x ^ x]'",
"[e ^ (ln x ^ x)]'", "[e ^ (ln(x ^ x))]'",
"e ^ (ln x ^ x)[ln x ^ x]'", "e ^ (ln(x ^ x))[ln(x ^ x)]'",
"x ^ x * [ln x ^ x]'", "x ^ x * [ln(x ^ x)]'",
"x ^ x * [x ln x]'", "x ^ x * [x ln(x)]'",
"x ^ x * ([x]' * ln x + x[ln x]')", "x ^ x * ([x]' * ln(x) + x[ln(x)]')",
"x ^ x * (1ln x + x[ln x]')", "x ^ x * (1ln(x) + x[ln(x)]')",
"x ^ x * (ln x + x[ln x]')", "x ^ x * (ln(x) + x[ln(x)]')",
"x ^ x * (ln x + x * 1 / x)", "x ^ x * (ln(x) + x * 1 / x)",
"x ^ x * (ln x + (x * 1) / x)", "x ^ x * (ln(x) + (x * 1) / x)",
"x ^ x * (ln x + x / x)", "x ^ x * (ln(x) + x / x)",
"x ^ x * (ln x + 1)", "x ^ x * (ln(x) + 1)",
"x ^ x * ln x + x ^ x * 1", "x ^ x * ln(x) + x ^ x * 1",
"x ^ x * ln x + x ^ x", "x ^ x * ln(x) + x ^ x",
]) ])
def test_variable_root(self): def test_variable_root(self):
......
...@@ -142,9 +142,9 @@ class TestRulesIntegrals(RulesTestCase): ...@@ -142,9 +142,9 @@ class TestRulesIntegrals(RulesTestCase):
'int a / x', 'int a / x',
'int a * 1 / x dx', 'int a * 1 / x dx',
'a(int 1 / x dx)', 'a(int 1 / x dx)',
'a(ln|x| + C)', 'a(ln(|x|) + C)',
'a ln|x| + aC', 'a ln(|x|) + aC',
# FIXME: 'aln|x| + C', # ac -> C # FIXME: 'a ln(|x|) + C', # ac -> C
]) ])
def test_match_function_integral(self): def test_match_function_integral(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment