Skip to content
Snippets Groups Projects
Commit bbdcbe9c authored by Taddeüs Kroes's avatar Taddeüs Kroes
Browse files

Implemented correct parentheses parsing for simple functions

parent cc5e8713
No related branches found
No related tags found
No related merge requests found
...@@ -101,7 +101,7 @@ class Parser(BisonParser): ...@@ -101,7 +101,7 @@ class Parser(BisonParser):
# TODO: add a runtime check to verify that this token list match the list # TODO: add a runtime check to verify that this token list match the list
# of tokens of the lex script. # of tokens of the lex script.
tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH', tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'FUNCTION_PAREN',
'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \ 'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \
+ filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values()) + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
...@@ -121,6 +121,7 @@ class Parser(BisonParser): ...@@ -121,6 +121,7 @@ class Parser(BisonParser):
('nonassoc', ('NEG', )), ('nonassoc', ('NEG', )),
('nonassoc', ('FUNCTION', 'LOGARITHM')), ('nonassoc', ('FUNCTION', 'LOGARITHM')),
('right', ('POW', 'SUB')), ('right', ('POW', 'SUB')),
('nonassoc', ('FUNCTION_PAREN', )),
) )
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -502,6 +503,7 @@ class Parser(BisonParser): ...@@ -502,6 +503,7 @@ class Parser(BisonParser):
""" """
unary : MINUS exp %prec NEG unary : MINUS exp %prec NEG
| FUNCTION exp | FUNCTION exp
| FUNCTION_PAREN exp RPAREN
| raised_function exp %prec FUNCTION | raised_function exp %prec FUNCTION
| DERIVATIVE exp | DERIVATIVE exp
| exp PRIME | exp PRIME
...@@ -518,13 +520,15 @@ class Parser(BisonParser): ...@@ -518,13 +520,15 @@ class Parser(BisonParser):
return values[1] return values[1]
if option == 1: # rule: FUNCTION exp if option in (1, 2): # rule: FUNCTION exp | FUNCTION_PAREN exp RPAREN
fun = values[0] if option == 1 else values[0][:-1].rstrip()
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
return Node(values[0], *values[1]) return Node(fun, *values[1])
return Node(values[0], values[1]) return Node(fun, values[1])
if option == 2: # rule: raised_function exp if option == 3: # rule: raised_function exp
func, exponent = values[0] func, exponent = values[0]
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
...@@ -532,26 +536,26 @@ class Parser(BisonParser): ...@@ -532,26 +536,26 @@ class Parser(BisonParser):
return Node(OP_POW, Node(func, values[1]), exponent) return Node(OP_POW, Node(func, values[1]), exponent)
if option == 3: # rule: DERIVATIVE exp if option == 4: # rule: DERIVATIVE exp
# DERIVATIVE looks like 'd/d*x' -> extract the 'x' # DERIVATIVE looks like 'd/d*x' -> extract the 'x'
return Node(OP_DXDER, values[1], Leaf(values[0][-1])) return Node(OP_DXDER, values[1], Leaf(values[0][-1]))
if option == 4: # rule: exp PRIME if option == 5: # rule: exp PRIME
return Node(OP_PRIME, values[0]) return Node(OP_PRIME, values[0])
if option == 5: # rule: INTEGRAL exp if option == 6: # rule: INTEGRAL exp
fx, x = find_integration_variable(values[1]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x) return Node(OP_INT, fx, x)
if option == 6: # rule: integral_bounds exp if option == 7: # rule: integral_bounds exp
lbnd, ubnd = values[0] lbnd, ubnd = values[0]
fx, x = find_integration_variable(values[1]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x, lbnd, ubnd) return Node(OP_INT, fx, x, lbnd, ubnd)
if option == 7: # rule: PIPE exp PIPE if option == 8: # rule: PIPE exp PIPE
return Node(OP_ABS, values[1]) return Node(OP_ABS, values[1])
if option == 8: # rule: LOGARITHM exp if option == 9: # rule: LOGARITHM exp
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
return Node(OP_LOG, *values[1]) return Node(OP_LOG, *values[1])
...@@ -562,14 +566,14 @@ class Parser(BisonParser): ...@@ -562,14 +566,14 @@ class Parser(BisonParser):
return Node(OP_LOG, values[1], Leaf(base)) return Node(OP_LOG, values[1], Leaf(base))
if option == 9: # rule: logarithm_subscript exp if option == 10: # rule: logarithm_subscript exp
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
raise BisonSyntaxError('Shortcut logarithm base "log_%s" does ' raise BisonSyntaxError('Shortcut logarithm base "log_%s" does '
'not support additional arguments.' % (values[0])) 'not support additional arguments.' % (values[0]))
return Node(OP_LOG, values[1], values[0]) return Node(OP_LOG, values[1], values[0])
if option == 10: # rule: TIMES exp if option == 11: # rule: TIMES exp
return values[1] return values[1]
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
...@@ -715,8 +719,9 @@ class Parser(BisonParser): ...@@ -715,8 +719,9 @@ class Parser(BisonParser):
# Put all functions in a single regex # Put all functions in a single regex
if functions: if functions:
operators += '("%s") { returntoken(FUNCTION); }\n' \ fun_or = '("' + '"|"'.join(functions) + '")'
% '"|"'.join(functions) operators += fun_or + ' { returntoken(FUNCTION); }\n'
operators += fun_or + '[ ]*\( { returntoken(FUNCTION_PAREN); }\n'
# ----------------------------------------- # -----------------------------------------
# raw lex script, verbatim here # raw lex script, verbatim here
......
...@@ -88,19 +88,23 @@ class TestParser(RulesTestCase): ...@@ -88,19 +88,23 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('pi2'), tree('pi * 2')) self.assertEqual(tree('pi2'), tree('pi * 2'))
def test_functions(self): def test_function(self):
x = tree('x') x = tree('x')
self.assertEqual(tree('sin x'), sin(x)) self.assertEqual(tree('sin x'), sin(x))
self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct? self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct?
self.assertEqual(tree('sin x ^ 2'), sin(x ** 2)) self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
self.assertEqual(tree('sin^2 x'), sin(x) ** 2) self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2)) self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2)
self.assertEqual(tree('sin cos x'), sin(cos(x))) self.assertEqual(tree('sin cos x'), sin(cos(x)))
self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x)) ** 2)
self.assertEqual(tree('sin((cos x) ^ 2)'), sin(cos(x) ** 2))
def test_brackets(self): def test_brackets(self):
self.assertEqual(*tree('[x], x')) self.assertEqual(*tree('[x], x'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment