Commit bbdcbe9c authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented correct parentheses parsing for simple functions

parent cc5e8713
...@@ -101,7 +101,7 @@ class Parser(BisonParser): ...@@ -101,7 +101,7 @@ class Parser(BisonParser):
# TODO: add a runtime check to verify that this token list match the list # TODO: add a runtime check to verify that this token list match the list
# of tokens of the lex script. # of tokens of the lex script.
tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH', tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'FUNCTION_PAREN',
'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \ 'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \
+ filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values()) + filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
...@@ -121,6 +121,7 @@ class Parser(BisonParser): ...@@ -121,6 +121,7 @@ class Parser(BisonParser):
('nonassoc', ('NEG', )), ('nonassoc', ('NEG', )),
('nonassoc', ('FUNCTION', 'LOGARITHM')), ('nonassoc', ('FUNCTION', 'LOGARITHM')),
('right', ('POW', 'SUB')), ('right', ('POW', 'SUB')),
('nonassoc', ('FUNCTION_PAREN', )),
) )
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -502,6 +503,7 @@ class Parser(BisonParser): ...@@ -502,6 +503,7 @@ class Parser(BisonParser):
""" """
unary : MINUS exp %prec NEG unary : MINUS exp %prec NEG
| FUNCTION exp | FUNCTION exp
| FUNCTION_PAREN exp RPAREN
| raised_function exp %prec FUNCTION | raised_function exp %prec FUNCTION
| DERIVATIVE exp | DERIVATIVE exp
| exp PRIME | exp PRIME
...@@ -518,13 +520,15 @@ class Parser(BisonParser): ...@@ -518,13 +520,15 @@ class Parser(BisonParser):
return values[1] return values[1]
if option == 1: # rule: FUNCTION exp if option in (1, 2): # rule: FUNCTION exp | FUNCTION_PAREN exp RPAREN
fun = values[0] if option == 1 else values[0][:-1].rstrip()
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
return Node(values[0], *values[1]) return Node(fun, *values[1])
return Node(values[0], values[1]) return Node(fun, values[1])
if option == 2: # rule: raised_function exp if option == 3: # rule: raised_function exp
func, exponent = values[0] func, exponent = values[0]
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
...@@ -532,26 +536,26 @@ class Parser(BisonParser): ...@@ -532,26 +536,26 @@ class Parser(BisonParser):
return Node(OP_POW, Node(func, values[1]), exponent) return Node(OP_POW, Node(func, values[1]), exponent)
if option == 3: # rule: DERIVATIVE exp if option == 4: # rule: DERIVATIVE exp
# DERIVATIVE looks like 'd/d*x' -> extract the 'x' # DERIVATIVE looks like 'd/d*x' -> extract the 'x'
return Node(OP_DXDER, values[1], Leaf(values[0][-1])) return Node(OP_DXDER, values[1], Leaf(values[0][-1]))
if option == 4: # rule: exp PRIME if option == 5: # rule: exp PRIME
return Node(OP_PRIME, values[0]) return Node(OP_PRIME, values[0])
if option == 5: # rule: INTEGRAL exp if option == 6: # rule: INTEGRAL exp
fx, x = find_integration_variable(values[1]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x) return Node(OP_INT, fx, x)
if option == 6: # rule: integral_bounds exp if option == 7: # rule: integral_bounds exp
lbnd, ubnd = values[0] lbnd, ubnd = values[0]
fx, x = find_integration_variable(values[1]) fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x, lbnd, ubnd) return Node(OP_INT, fx, x, lbnd, ubnd)
if option == 7: # rule: PIPE exp PIPE if option == 8: # rule: PIPE exp PIPE
return Node(OP_ABS, values[1]) return Node(OP_ABS, values[1])
if option == 8: # rule: LOGARITHM exp if option == 9: # rule: LOGARITHM exp
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
return Node(OP_LOG, *values[1]) return Node(OP_LOG, *values[1])
...@@ -562,14 +566,14 @@ class Parser(BisonParser): ...@@ -562,14 +566,14 @@ class Parser(BisonParser):
return Node(OP_LOG, values[1], Leaf(base)) return Node(OP_LOG, values[1], Leaf(base))
if option == 9: # rule: logarithm_subscript exp if option == 10: # rule: logarithm_subscript exp
if values[1].is_op(OP_COMMA): if values[1].is_op(OP_COMMA):
raise BisonSyntaxError('Shortcut logarithm base "log_%s" does ' raise BisonSyntaxError('Shortcut logarithm base "log_%s" does '
'not support additional arguments.' % (values[0])) 'not support additional arguments.' % (values[0]))
return Node(OP_LOG, values[1], values[0]) return Node(OP_LOG, values[1], values[0])
if option == 10: # rule: TIMES exp if option == 11: # rule: TIMES exp
return values[1] return values[1]
raise BisonSyntaxError('Unsupported option %d in target "%s".' raise BisonSyntaxError('Unsupported option %d in target "%s".'
...@@ -715,8 +719,9 @@ class Parser(BisonParser): ...@@ -715,8 +719,9 @@ class Parser(BisonParser):
# Put all functions in a single regex # Put all functions in a single regex
if functions: if functions:
operators += '("%s") { returntoken(FUNCTION); }\n' \ fun_or = '("' + '"|"'.join(functions) + '")'
% '"|"'.join(functions) operators += fun_or + ' { returntoken(FUNCTION); }\n'
operators += fun_or + '[ ]*\( { returntoken(FUNCTION_PAREN); }\n'
# ----------------------------------------- # -----------------------------------------
# raw lex script, verbatim here # raw lex script, verbatim here
......
...@@ -88,19 +88,23 @@ class TestParser(RulesTestCase): ...@@ -88,19 +88,23 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('pi2'), tree('pi * 2')) self.assertEqual(tree('pi2'), tree('pi * 2'))
def test_functions(self): def test_function(self):
x = tree('x') x = tree('x')
self.assertEqual(tree('sin x'), sin(x)) self.assertEqual(tree('sin x'), sin(x))
self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct? self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct?
self.assertEqual(tree('sin x ^ 2'), sin(x ** 2)) self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
self.assertEqual(tree('sin^2 x'), sin(x) ** 2) self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2)) self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2)
self.assertEqual(tree('sin cos x'), sin(cos(x))) self.assertEqual(tree('sin cos x'), sin(cos(x)))
self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2))) self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x)) ** 2)
self.assertEqual(tree('sin((cos x) ^ 2)'), sin(cos(x) ** 2))
def test_brackets(self): def test_brackets(self):
self.assertEqual(*tree('[x], x')) self.assertEqual(*tree('[x], x'))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment