소스 검색

Fixed negation/substraction n-ary bug and added error location tracking.

Sander Mathijs van Veen 14 년 전
부모
커밋
4ffbcb2c05
3개의 변경된 파일58개의 추가작업 그리고 30개의 파일을 삭제
  1. 1 1
      external/pybison
  2. 11 11
      src/node.py
  3. 46 18
      src/parser.py

+ 1 - 1
external/pybison

@@ -1 +1 @@
-Subproject commit eb1d1da4c21cc3f48cabe19485381e3f7e80f279
+Subproject commit 930b44e0021b0ecac97c37ac8078867b84d09e98

+ 11 - 11
src/node.py

@@ -37,10 +37,10 @@ TYPE_MAP = {
         str: TYPE_IDENTIFIER,
         }
 
-
-OPT_MAP = {
+OP_MAP = {
         '+': OP_ADD,
-        '-': OP_SUB,
+        # Either substitution or negation. Skip the operator sign in 'x' (= 2).
+        '-': lambda x: OP_SUB if len(x) > 2 else OP_NEG,
         '*': OP_MUL,
         '/': OP_DIV,
         '^': OP_POW,
@@ -54,7 +54,10 @@ class ExpressionNode(Node):
     def __init__(self, *args, **kwargs):
         super(ExpressionNode, self).__init__(*args, **kwargs)
         self.type = TYPE_OPERATOR
-        self.opt = OPT_MAP[args[0]]
+        self.op = OP_MAP[args[0]]
+
+        if hasattr(self.op, '__call__'):
+            self.op = self.op(args)
 
     def __str__(self):  # pragma: nocover
         return generate_line(self)
@@ -69,10 +72,10 @@ class ExpressionNode(Node):
         self.parent = None
 
     def is_power(self):
-        return self.opt == OP_POW
+        return self.op == OP_POW
 
     def is_nary(self):
-        return self.opt in [OP_ADD, OP_SUB, OP_MUL]
+        return self.op in [OP_ADD, OP_SUB, OP_MUL]
 
     def get_order(self):
         if self.is_power() and self[0].is_identifier() \
@@ -91,7 +94,7 @@ class ExpressionNode(Node):
         scope = []
 
         for child in self:
-            if not isinstance(child, Leaf) and child.opt == self.opt:
+            if not isinstance(child, Leaf) and child.op == self.op:
                 scope += child.get_scope()
             else:
                 scope.append(child)
@@ -103,10 +106,7 @@ class ExpressionLeaf(Leaf):
     def __init__(self, *args, **kwargs):
         super(ExpressionLeaf, self).__init__(*args, **kwargs)
 
-        for data_type, type_repr in TYPE_MAP.iteritems():
-            if isinstance(args[0], data_type):
-                self.type = type_repr
-                break
+        self.type = TYPE_MAP[type(args[0])]
 
     def get_order(self):
         if self.is_identifier():

+ 46 - 18
src/parser.py

@@ -19,10 +19,30 @@ sys.path.insert(1, EXTERNAL_MODS)
 from pybison import BisonParser, BisonSyntaxError
 from graph_drawing.graph import generate_graph
 
+from node import TYPE_OPERATOR, OP_ADD, OP_MUL, OP_SUB, OP_NEG
+
 
 # Check for n-ary operator in child nodes
-def combine(op, n):
-    return n.nodes if n.title() == op else [n]
+def combine(op, op_type, *nodes):
+    # At least return the operator.
+    res = [op]
+
+    for n in nodes:
+        try:
+            assert n.type != TYPE_OPERATOR or n.op != OP_NEG or len(n.nodes) == 1
+            assert n.type != TYPE_OPERATOR or n.op != OP_SUB or len(n.nodes) > 1
+        except AssertionError:
+            print n, type(n), n.type, OP_NEG, OP_SUB, n.nodes, len(n.nodes)
+            raise
+
+        # Merge the children for all nodes which have the same operator.
+        if n.type == TYPE_OPERATOR and n.op == op_type:
+            res += n.nodes
+        else:
+            res.append(n)
+
+    return res
+
 
 class Parser(BisonParser):
     """
@@ -68,7 +88,7 @@ class Parser(BisonParser):
             return ''
 
         try:
-            return raw_input('>>> ') + '\n'
+            return raw_input('>>> ' if self.interactive else '') + '\n'
         except EOFError:
             return ''
 
@@ -235,15 +255,13 @@ class Parser(BisonParser):
         """
 
         if option == 0:  # rule: exp PLUS exp
-            return Node('+', *(combine('+', values[0])
-                               + combine('+', values[2])))
+            return Node(*(combine('+', OP_ADD, values[0], values[2])))
 
         if option == 1:  # rule: exp MINUS exp
-            return Node('-', values[0], values[2])
+            return Node(*(combine('-', OP_SUB, values[0], values[2])))
 
         if option == 2:  # rule: exp TIMES exp
-            return Node('*', *(combine('*', values[0])
-                               + combine('*', values[2])))
+            return Node(*(combine('*', OP_MUL, values[0], values[2])))
 
         if option == 3:  # rule: exp DIVIDE exp
             return Node('/', values[0], values[2])
@@ -293,7 +311,6 @@ class Parser(BisonParser):
     # -----------------------------------------
     lexscript = r"""
     %{
-    //int yylineno = 0;
     #include "Python.h"
     #define YYSTYPE void *
     #include "tokens.h"
@@ -301,13 +318,24 @@ class Parser(BisonParser):
     extern void (*py_input)(PyObject *parser, char *buf, int *result,
                             int max_size);
     #define returntoken(tok) \
-        yylval = PyString_FromString(strdup(yytext)); return (tok);
+            yylval = PyString_FromString(strdup(yytext)); return (tok);
     #define YY_INPUT(buf,result,max_size) { \
-        (*py_input)(py_parser, buf, &result, max_size); \
+            (*py_input)(py_parser, buf, &result, max_size); \
     }
+
+    int yycolumn = 0;
+
+    #define YY_USER_ACTION \
+            yylloc.first_line = yylloc.last_line = yylineno; \
+            yylloc.first_column = yycolumn; \
+            yylloc.last_column = yycolumn + yyleng; \
+            yycolumn += yyleng;
+
     /*[a-zA-Z][0-9]+ { returntoken(CONCAT_POW); }*/
     %}
 
+    %option yylineno
+
     %%
 
     [0-9]+    { returntoken(NUMBER); }
@@ -320,14 +348,13 @@ class Parser(BisonParser):
     "^"       { returntoken(POW); }
     "/"       { returntoken(DIVIDE); }
     ","       { returntoken(COMMA); }
-    "quit"    { printf("lex: got QUIT\n"); yyterminate(); returntoken(QUIT); }
+    "quit"    { yyterminate(); returntoken(QUIT); }
     "raise"   { returntoken(RAISE); }
     "graph"   { returntoken(GRAPH); }
 
-    [ \t\v\f] {}
-    [\n]      {yylineno++; returntoken(NEWLINE); }
-    .         { printf("unknown char %c ignored, yytext=%p\n",
-                yytext[0], yytext); /* ignore bad chars */}
+    [ \t\v\f] { }
+    [\n]      { yycolumn = 0; returntoken(NEWLINE); }
+    .         { printf("unknown char %c ignored.\n", yytext[0]); }
 
     %%
 
@@ -350,15 +377,16 @@ def get_args():
 
 def main():
     args = get_args()
+    interactive = not args.batch and sys.stdin.isatty()
 
     p = Parser(verbose=args.verbose,
                keepfiles=args.keepfiles,
-               interactive=not args.batch)
+               interactive=interactive)
 
     node = p.run(debug=args.debug)
 
     # Clear the line, when the shell exits.
-    if not args.batch:
+    if interactive:
         print
 
     return node