Commit bbdcbe9c authored by Taddeüs Kroes's avatar Taddeüs Kroes

Implemented correct parentheses parsing for simple functions

parent cc5e8713
......@@ -101,7 +101,7 @@ class Parser(BisonParser):
# TODO: add a runtime check to verify that this token list match the list
# of tokens of the lex script.
tokens = ['NUMBER', 'IDENTIFIER', 'NEWLINE', 'QUIT', 'RAISE', 'GRAPH',
'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET',
'LPAREN', 'RPAREN', 'FUNCTION', 'LBRACKET', 'FUNCTION_PAREN',
'RBRACKET', 'LCBRACKET', 'RCBRACKET', 'PIPE'] \
+ filter(lambda t: t != 'FUNCTION', TOKEN_MAP.values())
......@@ -121,6 +121,7 @@ class Parser(BisonParser):
('nonassoc', ('NEG', )),
('nonassoc', ('FUNCTION', 'LOGARITHM')),
('right', ('POW', 'SUB')),
('nonassoc', ('FUNCTION_PAREN', )),
)
def __init__(self, **kwargs):
......@@ -502,6 +503,7 @@ class Parser(BisonParser):
"""
unary : MINUS exp %prec NEG
| FUNCTION exp
| FUNCTION_PAREN exp RPAREN
| raised_function exp %prec FUNCTION
| DERIVATIVE exp
| exp PRIME
......@@ -518,13 +520,15 @@ class Parser(BisonParser):
return values[1]
if option == 1: # rule: FUNCTION exp
if option in (1, 2): # rule: FUNCTION exp | FUNCTION_PAREN exp RPAREN
fun = values[0] if option == 1 else values[0][:-1].rstrip()
if values[1].is_op(OP_COMMA):
return Node(values[0], *values[1])
return Node(fun, *values[1])
return Node(values[0], values[1])
return Node(fun, values[1])
if option == 2: # rule: raised_function exp
if option == 3: # rule: raised_function exp
func, exponent = values[0]
if values[1].is_op(OP_COMMA):
......@@ -532,26 +536,26 @@ class Parser(BisonParser):
return Node(OP_POW, Node(func, values[1]), exponent)
if option == 3: # rule: DERIVATIVE exp
if option == 4: # rule: DERIVATIVE exp
# DERIVATIVE looks like 'd/d*x' -> extract the 'x'
return Node(OP_DXDER, values[1], Leaf(values[0][-1]))
if option == 4: # rule: exp PRIME
if option == 5: # rule: exp PRIME
return Node(OP_PRIME, values[0])
if option == 5: # rule: INTEGRAL exp
if option == 6: # rule: INTEGRAL exp
fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x)
if option == 6: # rule: integral_bounds exp
if option == 7: # rule: integral_bounds exp
lbnd, ubnd = values[0]
fx, x = find_integration_variable(values[1])
return Node(OP_INT, fx, x, lbnd, ubnd)
if option == 7: # rule: PIPE exp PIPE
if option == 8: # rule: PIPE exp PIPE
return Node(OP_ABS, values[1])
if option == 8: # rule: LOGARITHM exp
if option == 9: # rule: LOGARITHM exp
if values[1].is_op(OP_COMMA):
return Node(OP_LOG, *values[1])
......@@ -562,14 +566,14 @@ class Parser(BisonParser):
return Node(OP_LOG, values[1], Leaf(base))
if option == 9: # rule: logarithm_subscript exp
if option == 10: # rule: logarithm_subscript exp
if values[1].is_op(OP_COMMA):
raise BisonSyntaxError('Shortcut logarithm base "log_%s" does '
'not support additional arguments.' % (values[0]))
return Node(OP_LOG, values[1], values[0])
if option == 10: # rule: TIMES exp
if option == 11: # rule: TIMES exp
return values[1]
raise BisonSyntaxError('Unsupported option %d in target "%s".'
......@@ -715,8 +719,9 @@ class Parser(BisonParser):
# Put all functions in a single regex
if functions:
operators += '("%s") { returntoken(FUNCTION); }\n' \
% '"|"'.join(functions)
fun_or = '("' + '"|"'.join(functions) + '")'
operators += fun_or + ' { returntoken(FUNCTION); }\n'
operators += fun_or + '[ ]*\( { returntoken(FUNCTION_PAREN); }\n'
# -----------------------------------------
# raw lex script, verbatim here
......
......@@ -88,19 +88,23 @@ class TestParser(RulesTestCase):
self.assertEqual(tree('pi2'), tree('pi * 2'))
def test_functions(self):
def test_function(self):
x = tree('x')
self.assertEqual(tree('sin x'), sin(x))
self.assertEqual(tree('sin 2 x'), sin(2) * x) # FIXME: correct?
self.assertEqual(tree('sin x ^ 2'), sin(x ** 2))
self.assertEqual(tree('sin^2 x'), sin(x) ** 2)
self.assertEqual(tree('sin(x ^ 2)'), sin(x ** 2))
self.assertEqual(tree('sin(x) ^ 2'), sin(x) ** 2)
self.assertEqual(tree('sin cos x'), sin(cos(x)))
self.assertEqual(tree('sin cos x ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x ** 2)))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin cos(x) ^ 2'), sin(cos(x) ** 2))
self.assertEqual(tree('sin (cos x) ^ 2'), sin(cos(x)) ** 2)
self.assertEqual(tree('sin((cos x) ^ 2)'), sin(cos(x) ** 2))
def test_brackets(self):
self.assertEqual(*tree('[x], x'))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment