Pārlūkot izejas kodu

Implemented correct parentheses parsing for simple functions

Taddeus Kroes 13 gadi atpakaļ
vecāks
revīzija
bbdcbe9c23
2 mainītis faili ar 27 papildinājumiem un 18 dzēšanām
  1. 20 15
      src/parser.py
  2. 7 3
      tests/test_parser.py

+ 20 - 15
src/parser.py

@@ -101,7 +101,7 @@ class Parser(BisonParser):
     # TODO: add a runtime check to verify that this token list match the list
     # of tokens of the lex script.
     tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
-              'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET',
+              'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'FUNCTION_PAREN',
               'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \
               + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
 
@@ -121,6 +121,7 @@ class Parser(BisonParser):
         ('nonassoc', ('NEG', )),
         ('nonassoc', ('FUNCTION', 'LOGARITHM')),
         ('right', ('POW', 'SUB')),
+        ('nonassoc', ('FUNCTION_PAREN', )),
         )
 
     def __init__(self, **kwargs):
@@ -502,6 +503,7 @@ class Parser(BisonParser):
         """
         unary : MINUS exp %prec NEG
               | FUNCTION exp
+              | FUNCTION_PAREN exp RPAREN
               | raised_function exp %prec FUNCTION
               | DERIVATIVE exp
               | exp PRIME
@@ -518,13 +520,15 @@ class Parser(BisonParser):
 
             return values[1]
 
-        if option == 1:  # rule: FUNCTION exp
+        if option in (1, 2):  # rule: FUNCTION exp | FUNCTION_PAREN exp RPAREN
+            fun = values[0] if option == 1 else values[0][:-1].rstrip()
+
             if values[1].is_op(OP_COMMA):
-                return Node(values[0], *values[1])
+                return Node(fun, *values[1])
 
-            return Node(values[0], values[1])
+            return Node(fun, values[1])
 
-        if option == 2:  # rule: raised_function exp
+        if option == 3:  # rule: raised_function exp
             func, exponent = values[0]
 
             if values[1].is_op(OP_COMMA):
@@ -532,26 +536,26 @@ class Parser(BisonParser):
 
             return Node(OP_POW, Node(func, values[1]), exponent)
 
-        if option == 3:  # rule: DERIVATIVE exp
+        if option == 4:  # rule: DERIVATIVE exp
             # DERIVATIVE looks like 'd/d*x' -> extract the 'x'
             return Node(OP_DXDER, values[1], Leaf(values[0][-1]))
 
-        if option == 4:  # rule: exp PRIME
+        if option == 5:  # rule: exp PRIME
             return Node(OP_PRIME, values[0])
 
-        if option == 5:  # rule: INTEGRAL exp
+        if option == 6:  # rule: INTEGRAL exp
             fx, x = find_integration_variable(values[1])
             return Node(OP_INT, fx, x)
 
-        if option == 6:  # rule: integral_bounds exp
+        if option == 7:  # rule: integral_bounds exp
             lbnd, ubnd = values[0]
             fx, x = find_integration_variable(values[1])
             return Node(OP_INT, fx, x, lbnd, ubnd)
 
-        if option == 7:  # rule: PIPE exp PIPE
+        if option == 8:  # rule: PIPE exp PIPE
             return Node(OP_ABS, values[1])
 
-        if option == 8:  # rule: LOGARITHM exp
+        if option == 9:  # rule: LOGARITHM exp
             if values[1].is_op(OP_COMMA):
                 return Node(OP_LOG, *values[1])
 
@@ -562,14 +566,14 @@ class Parser(BisonParser):
 
             return Node(OP_LOG, values[1], Leaf(base))
 
-        if option == 9:  # rule: logarithm_subscript exp
+        if option == 10:  # rule: logarithm_subscript exp
             if values[1].is_op(OP_COMMA):
                 raise BisonSyntaxError('Shortcut logarithm base "log_%s" does '
                         'not support additional arguments.' % (values[0]))
 
             return Node(OP_LOG, values[1], values[0])
 
-        if option == 10:  # rule: TIMES exp
+        if option == 11:  # rule: TIMES exp
             return values[1]
 
         raise BisonSyntaxError('Unsupported option %d in target "%s".'
@@ -715,8 +719,9 @@ class Parser(BisonParser):
 
     # Put all functions in a single regex
     if functions:
-        operators += '("%s") { returntoken(FUNCTION); }\n' \
-                     % '"|"'.join(functions)
+        fun_or = '("' + '"|"'.join(functions) + '")'
+        operators += fun_or + ' { returntoken(FUNCTION); }\n'
+        operators += fun_or + '[ ]*\( { returntoken(FUNCTION_PAREN); }\n'
 
     # -----------------------------------------
     # raw lex script, verbatim here

+ 7 - 3
tests/test_parser.py

@@ -88,19 +88,23 @@ class TestParser(RulesTestCase):
 
         self.assertEqual(tree('pi2'), tree('pi * 2'))
 
-    def test_functions(self):
+    def test_function(self):
         x = tree('x')
 
         self.assertEqual(tree('sin x'), sin(x))
         self.assertEqual(tree('sin 2 x'), sin(2) * x)  # FIXME: correct?
         self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
         self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
+
         self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
+        self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2)
 
         self.assertEqual(tree('sin cos x'), sin(cos(x)))
         self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
-        self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2)))
-        self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
+        self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2))
+
+        self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x)) ** 2)
+        self.assertEqual(tree('sin((cos x) ^ 2)'), sin(cos(x) ** 2))
 
     def test_brackets(self):
         self.assertEqual(*tree('[x], x'))