Bladeren bron

Implemented correct parentheses parsing for logarithms

Taddeus Kroes 13 jaren geleden
bovenliggende
commit
e9dd48f86f
4 gewijzigde bestanden met toevoegingen van 52 en 29 verwijderingen
  1. 21 5
      src/node.py
  2. 15 8
      tests/test_node.py
  3. 13 13
      tests/test_rules_derivatives.py
  4. 3 3
      tests/test_rules_integrals.py

+ 21 - 5
src/node.py

@@ -16,6 +16,7 @@
 import os.path
 import sys
 import copy
+import re
 
 sys.path.insert(0, os.path.realpath('external'))
 
@@ -324,22 +325,33 @@ class ExpressionNode(Node, ExpressionBase):
         if self.op in UNARY_FUNCTIONS:
             return 1
 
-        if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE):
-            return 1
+        #if self.op == OP_LOG and self[1].value in (E, DEFAULT_LOGARITHM_BASE):
+        #    return 1
+
+        # Functions always have parentheses
+        if self.op in TOKEN_MAP and TOKEN_MAP[self.op] == 'FUNCTION':
+            return 2
 
         return len(self)
 
     def operator(self):
+        # Append an opening parenthesis manually, the closing parentheses is
+        # appended by postprocess_str
         if self.op == OP_LOG:
             base = self[1].value
 
             if base == DEFAULT_LOGARITHM_BASE:
-                return self.value
+                return self.value + '('
 
             if base == E:
-                return 'ln'
+                return 'ln('
+
+            base = str(self[1])
 
-            return '%s_%s' % (self.value, str(self[1]))
+            if not re.match('^[0-9]+|[a-zA-Z]$', base):
+                base = '(' + base + ')'
+
+            return '%s_%s(' % (self.value, base)
 
         if self.op == OP_DXDER:
             return self.value + str(self[1])
@@ -365,6 +377,10 @@ class ExpressionNode(Node, ExpressionBase):
             self[0] = ExpressionNode(OP_BRACKETS, self[0])
 
     def postprocess_str(self, s):
+        # A bit hacky, but forced because of operator() method
+        if self.op == OP_LOG:
+            return s.replace('( ', '(') + ')'
+
         if self.op == OP_INT:
             return '%s d%s' % (s, self[1])
 

+ 15 - 8
tests/test_node.py

@@ -248,12 +248,6 @@ class TestNode(RulesTestCase):
         self.assertEqual(str(tree("(x ^ 2)''")), "[x ^ 2]''")
         self.assertEqual(str(tree('d/dx x ^ 2')), 'd/dx x ^ 2')
 
-    def test_construct_function_logarithm(self):
-        self.assertEqual(str(tree('log(x, e)')), 'ln x')
-        self.assertEqual(str(tree('log(x, 10)')), 'log x')
-        self.assertEqual(str(tree('log(x, 2)')), 'log_2 x')
-        self.assertEqual(str(tree('log(x, g)')), 'log_g x')
-
     def test_construct_function_integral(self):
         self.assertEqual(str(tree('int x ^ 2')), 'int x ^ 2 dx')
         self.assertEqual(str(tree('int x ^ 2 dx')), 'int x ^ 2 dx')
@@ -271,8 +265,21 @@ class TestNode(RulesTestCase):
                          '[x ^ 2]_(a - b)^(a + b)')
 
     def test_construct_function_absolute_child(self):
-        self.assertEqual(str(tree('ln(|x|)')), 'ln|x|')
-        self.assertEqual(str(tree('sin(|x|)')), 'sin|x|')
+        self.assertEqual(str(tree('ln(|x|)')), 'ln(|x|)')
+        self.assertEqual(str(tree('sin(|x|)')), 'sin(|x|)')
+
+    def test_construct_logarithm(self):
+        self.assertEqual(str(tree('log n')), 'log(n)')
+        self.assertEqual(str(tree('log(n)')), 'log(n)')
+
+        self.assertEqual(str(tree('ln n')), 'ln(n)')
+        self.assertEqual(str(tree('ln(n)')), 'ln(n)')
+
+        self.assertEqual(str(tree('log_2 n')), 'log_2(n)')
+        self.assertEqual(str(tree('log_2(n)')), 'log_2(n)')
+
+        self.assertEqual(str(tree('log_g n')), 'log_g(n)')
+        self.assertEqual(str(tree('log_(g + h) n')), 'log_(g + h)(n)')
 
     def test_infinity(self):
         self.assertEqual(infinity(), tree('oo'))

+ 13 - 13
tests/test_rules_derivatives.py

@@ -118,19 +118,19 @@ class TestRulesDerivatives(RulesTestCase):
     def test_power_rule_chain(self):
         self.assertRewrite([
             "[x ^ x]'",
-            "[e ^ (ln x ^ x)]'",
-            "e ^ (ln x ^ x)[ln x ^ x]'",
-            "x ^ x * [ln x ^ x]'",
-            "x ^ x * [x ln x]'",
-            "x ^ x * ([x]' * ln x + x[ln x]')",
-            "x ^ x * (1ln x + x[ln x]')",
-            "x ^ x * (ln x + x[ln x]')",
-            "x ^ x * (ln x + x * 1 / x)",
-            "x ^ x * (ln x + (x * 1) / x)",
-            "x ^ x * (ln x + x / x)",
-            "x ^ x * (ln x + 1)",
-            "x ^ x * ln x + x ^ x * 1",
-            "x ^ x * ln x + x ^ x",
+            "[e ^ (ln(x ^ x))]'",
+            "e ^ (ln(x ^ x))[ln(x ^ x)]'",
+            "x ^ x * [ln(x ^ x)]'",
+            "x ^ x * [x ln(x)]'",
+            "x ^ x * ([x]' * ln(x) + x[ln(x)]')",
+            "x ^ x * (1ln(x) + x[ln(x)]')",
+            "x ^ x * (ln(x) + x[ln(x)]')",
+            "x ^ x * (ln(x) + x * 1 / x)",
+            "x ^ x * (ln(x) + (x * 1) / x)",
+            "x ^ x * (ln(x) + x / x)",
+            "x ^ x * (ln(x) + 1)",
+            "x ^ x * ln(x) + x ^ x * 1",
+            "x ^ x * ln(x) + x ^ x",
         ])
 
     def test_variable_root(self):

+ 3 - 3
tests/test_rules_integrals.py

@@ -142,9 +142,9 @@ class TestRulesIntegrals(RulesTestCase):
             'int a / x',
             'int a * 1 / x dx',
             'a(int 1 / x dx)',
-            'a(ln|x| + C)',
-            'a ln|x| + aC',
-            # FIXME: 'aln|x| + C',  # ac -> C
+            'a(ln(|x|) + C)',
+            'a ln(|x|) + aC',
+            # FIXME: 'a ln(|x|) + C',  # ac -> C
         ])
 
     def test_match_function_integral(self):