Ver Fonte

Implemented scalar array initialisation

Taddeus Kroes há 12 anos atrás
pai
commit
f282d9b512
7 ficheiros alterados com 114 adições e 91 exclusões
  1. 1 1
      ast.ml
  2. 1 0
      main.ml
  3. 5 0
      phases/context_analysis.ml
  4. 70 70
      phases/desug.ml
  5. 31 17
      phases/typecheck.ml
  6. 3 2
      stringify.ml
  7. 3 1
      util.ml

+ 1 - 1
ast.ml

@@ -37,7 +37,6 @@ and node =
     | IntConst of int * loc
     | FloatConst of float * loc
     | ArrayConst of node list * loc
-    | ArrayScalar of node * loc
     | Var of string * loc
     | Deref of string * node list * loc
     | Monop of operator * node * loc
@@ -47,6 +46,7 @@ and node =
     | Arg of node
 
     (* additional types for convenience in traversals *)
+    | ArrayScalar of node * ctype
     | Cond of node * node * node * loc
     | VarLet of node * ctype * int
     | VarUse of node * ctype * int

+ 1 - 0
main.ml

@@ -18,6 +18,7 @@ let compile args =
         Desug.phase;
         Print.phase;
         Context_analysis.phase;
+        Print.phase;
         Typecheck.phase;
         Print.phase;
         Expand_dims.phase;

+ 5 - 0
phases/context_analysis.ml

@@ -75,6 +75,11 @@ let rec analyse scope depth args node =
             let (decl, dec_depth) = check_in_scope (Varname name) node scope in
             VarUse (node, ctypeof decl, depth - dec_depth)
 
+        | Deref (name, dims, loc) ->
+            let (decl, dec_depth) = check_in_scope (Varname name) node scope in
+            let node = Deref (name, List.map collect dims, loc) in
+            VarUse (node, ctypeof decl, depth - dec_depth)
+
         | FunCall (name, args, loc) ->
             let (decl, dec_depth) = check_in_scope (Funcname name) node scope in
             let node = FunCall (name, transform_all collect args, loc) in

+ 70 - 70
phases/desug.ml

@@ -1,53 +1,6 @@
 open Ast
 open Util
 
-let rec replace_var var replacement node =
-    let trav = (replace_var var replacement) in
-    match node with
-    | Var (name, loc) when name = var ->
-        Var (replacement, loc)
-    | For (counter, start, stop, step, body, loc) when counter = var ->
-        For (replacement, trav start, trav stop, trav step, trav body, loc)
-    | node ->
-        transform_children trav node
-
-let for_to_while node =
-    let new_vars = ref [] in
-    let rec traverse = function
-        (* Do not traverse into local functions (already done by var_init) *)
-        | FunDef (_, _, _, _, _, _) as node -> node
-
-        (* Transform for-loops to while-loops *)
-        | For (counter, start, stop, step, body, loc) ->
-            let _i = fresh_var counter in
-            let _stop = fresh_var "stop" in
-            let _step = fresh_var "step" in
-            new_vars := !new_vars @ [_i; _stop; _step];
-
-            let vi = Var (_i, noloc) in
-            let vstop = Var (_stop, locof stop) in
-            let vstep = Var (_step, locof step) in
-            let cond = Cond (
-                Binop (Gt, vstep, IntConst (0, noloc), noloc),
-                Binop (Lt, vi, vstop, noloc),
-                Binop (Gt, vi, vstop, noloc),
-                noloc
-            ) in
-            Block [
-                Assign (_i, None, start, locof start);
-                Assign (_stop, None, stop, locof stop);
-                Assign (_step, None, step, locof step);
-                While (cond, traverse (Block (
-                    (* TODO: check for illegal assigments of counter in body *)
-                    block_body (replace_var counter _i body) @
-                    [Assign (_i, None, Binop (Add, vi, vstep, noloc), noloc)]
-                )), loc);
-            ]
-
-        | node -> transform_children traverse node
-    in
-    (traverse node, new_vars)
-
 let rec var_init = function
     (* Move global initialisations to __init function *)
     | Program (decls, loc) ->
@@ -78,13 +31,13 @@ let rec var_init = function
             let rec trav inits node = match node with
                 (* translate scalar array initialisation to ArrayScalar node,
                  * for easy replacement later on *)
-                | VarDec (ArrayDef (_, _) as vtype, name,
-                          Some ((BoolConst  (_, l)) as v), loc) :: t
-                | VarDec (ArrayDef (_, _) as vtype, name,
-                          Some ((FloatConst (_, l)) as v), loc) :: t
-                | VarDec (ArrayDef (_, _) as vtype, name,
-                          Some ((IntConst   (_, l)) as v), loc) :: t ->
-                    trav inits (VarDec (vtype, name, Some (ArrayScalar (v, l)), loc) :: t)
+                | VarDec (ArrayDef _ as vtype, name,
+                          Some (BoolConst _  as v), loc) :: t
+                | VarDec (ArrayDef _ as vtype, name,
+                          Some (FloatConst _ as v), loc) :: t
+                | VarDec (ArrayDef _ as vtype, name,
+                          Some (IntConst _   as v), loc) :: t ->
+                    trav inits (VarDec (vtype, name, Some (ArrayScalar (v, vtype)), loc) :: t)
 
                 | VarDec (ctype, name, init, loc) as dec :: tl ->
                     (* array definition: create __allocate statement *)
@@ -109,36 +62,83 @@ let rec var_init = function
             flatten_blocks (trav [] body)
         in
         let params = flatten_blocks (List.map var_init params) in
-        let (body, new_vars) = for_to_while (Block (move_inits body)) in
-        let create_vardec name = VarDec (Int, name, None, noloc) in
-        let new_vardecs = List.map create_vardec !new_vars in
-        let stats = new_vardecs @ (flatten_blocks (block_body body)) in
-        FunDef (export, ret_type, name, params, Block stats, loc)
+        FunDef (export, ret_type, name, params, Block (move_inits body), loc)
 
     | node -> transform_children var_init node
 
-(*
+let rec replace_var var replacement node =
+    let trav = (replace_var var replacement) in
+    match node with
+    | Var (name, loc) when name = var ->
+        Var (replacement, loc)
+    | For (counter, start, stop, step, body, loc) when counter = var ->
+        For (replacement, trav start, trav stop, trav step, trav body, loc)
+    | node ->
+        transform_children trav node
+
+let for_to_while node =
+    let rec traverse new_vars = function
+        | FunDef (export, ret_type, name, params, body, loc) ->
+            let new_vars = ref [] in
+            let body = traverse new_vars body in
+            let create_vardec name = VarDec (Int, name, None, noloc) in
+            let new_vardecs = List.map create_vardec !new_vars in
+            let _body = new_vardecs @ (flatten_blocks (block_body body)) in
+            FunDef (export, ret_type, name, params, Block _body, loc)
+
+        (* Transform for-loops to while-loops *)
+        | For (counter, start, stop, step, body, loc) ->
+            let _i = fresh_var counter in
+            let _stop = fresh_var "stop" in
+            let _step = fresh_var "step" in
+            new_vars := !new_vars @ [_i; _stop; _step];
+
+            let vi = Var (_i, noloc) in
+            let vstop = Var (_stop, locof stop) in
+            let vstep = Var (_step, locof step) in
+            let cond = Cond (
+                Binop (Gt, vstep, IntConst (0, noloc), noloc),
+                Binop (Lt, vi, vstop, noloc),
+                Binop (Gt, vi, vstop, noloc),
+                noloc
+            ) in
+            Block [
+                Assign (_i, None, start, locof start);
+                Assign (_stop, None, stop, locof stop);
+                Assign (_step, None, step, locof step);
+                While (cond, traverse new_vars (Block (
+                    (* TODO: check for illegal assigments of counter in body *)
+                    block_body (replace_var counter _i body) @
+                    [Assign (_i, None, Binop (Add, vi, vstep, noloc), noloc)]
+                )), loc);
+            ]
+
+        | node -> transform_children (traverse new_vars) node
+    in
+    traverse (ref []) node
+
 let rec array_init = function
-    (* transform scalar assignment into nested for loops *)
-    | Assign (name, None, ArrayScalar (value), loc) ->
+    (* Transform scalar assignment into nested for-loops *)
+    | Assign (name, None, ArrayScalar (value, ArrayDef (_, dims)), loc) ->
         let rec add_loop indices = function
             | [] ->
-                Assign (name, indices, value, loc)
+                Assign (name, Some indices, value, loc)
             | dim :: rest ->
-                let counter = fresh_var "counter" in
-                let ind = (indices @ [Var counter]) in
-                For (counter, IntConst 0, dim, IntConst 1, add_loop ind rest)
+                let counter = fresh_var "i" in
+                let body = Block [add_loop (indices @ [Var (counter, noloc)]) rest] in
+                For (counter, IntConst (0, noloc), dim, IntConst (1, noloc), body, noloc)
         in
         add_loop [] dims
 
-    | Assign (name, None, ArrayConst (dims), loc) -> Block []
+    (* TODO *)
+    | Assign (name, None, ArrayConst (dims, _), _) ->
+        Block []
 
-    | node -> transform array_init node
-*)
+    | node -> transform_children array_init node
 
 let rec phase input =
     prerr_endline "- Desugaring";
     match input with
     | Ast (node, args) ->
-        Ast (var_init node, args)
+        Ast (for_to_while (array_init (var_init node)), args)
     | _ -> raise (InvalidInput "desugar")

+ 31 - 17
phases/typecheck.ml

@@ -21,9 +21,15 @@ open Stringify
  *)
 
 let spec = function
-    | ArrayDec (ctype, dims) -> ArraySpec (ctype, list_size dims)
+    | ArrayDec (ctype, dims)
+    | ArrayDef (ctype, dims) -> ArraySpec (ctype, list_size dims)
     | ctype                  -> ctype
 
+let array_width = function
+    | ArrayDec (_, dims)
+    | ArrayDef (_, dims) -> list_size dims
+    | _                  -> raise InvalidNode
+
 let check_type ?(msg="") expected = function
     | Type (node, got) when (spec got) <> (spec expected) ->
         let msg = match msg with
@@ -57,16 +63,17 @@ let check_type_op allowed_types desc = function
     | Type _ -> ()
     | _ -> raise InvalidNode
 
-let array_width = function
-    | ArrayDec (_, dims) -> list_size dims
-    | _                  -> raise InvalidNode
+let check_dims_match dims dec_type errnode =
+    match (list_size dims, array_width dec_type) with
+    | (got, expected) when got != expected ->
+        let msg =
+            sprintf "dimension mismatch: expected %d indices, got %d"
+                    expected got
+        in
+        raise (NodeError (errnode, msg))
+    | _ -> ()
 
 let rec typecheck node = match node with
-    | BoolConst  (value, _) -> Type (node, Bool)
-    | IntConst   (value, _) -> Type (node, Int)
-    | FloatConst (value, _) -> Type (node, Float)
-    | VarUse (_, ctype, _)  -> Type (node, ctype)
-
     | FunUse (FunCall (_, args, _), FunDef (_, ftype, name, params, _, _), _) ->
         (match (list_size args, list_size params) with
         | (nargs, nparams) when nargs != nparams ->
@@ -110,14 +117,7 @@ let rec typecheck node = match node with
         node
     | VarLet (Assign (_, Some dims, (Type _ as value), _) as assign, dec_type, depth) ->
         (* Number of assigned indices must match array definition *)
-        (match (list_size dims, array_width dec_type) with
-        | (got, expected) when got != expected ->
-            let msg =
-                sprintf "dimension mismatch: expected %d indices, got %d"
-                        expected got
-            in
-            raise (NodeError (assign, msg))
-        | _ -> ());
+        check_dims_match dims dec_type assign;
 
         (* Array indices must be ints *)
         List.iter (check_type Int) dims;
@@ -134,6 +134,15 @@ let rec typecheck node = match node with
     | TypeCast (ctype, value, loc) ->
         typecheck (TypeCast (ctype, typecheck value, loc))
 
+    | VarUse (Deref (_, dims, _) as deref, dec_type, depth) ->
+        let dims = List.map typecheck dims in
+        List.iter (check_type Int) dims;
+
+        check_dims_match dims dec_type deref;
+
+        let ctype = base_type dec_type in
+        typecheck (VarUse (Type (deref, ctype), ctype, depth))
+
     | Allocate (name, dims, dec, loc) ->
         let dims = List.map typecheck dims in
         List.iter (check_type Int) dims;
@@ -185,6 +194,11 @@ let rec typecheck node = match node with
     | DoWhile (cond, body, loc) ->
         typecheck (DoWhile (typecheck cond, typecheck body, loc))
 
+    | BoolConst  (value, _) -> Type (node, Bool)
+    | IntConst   (value, _) -> Type (node, Int)
+    | FloatConst (value, _) -> Type (node, Float)
+    | VarUse (_, ctype, _)  -> Type (node, ctype)
+
     | _ -> transform_children typecheck node
 
 let rec phase input =

+ 3 - 2
stringify.ml

@@ -100,7 +100,8 @@ and node2str node =
     | IntConst (i, _) -> string_of_int i
     | FloatConst (f, _) -> string_of_float f
     | ArrayConst (dims, _) -> "[" ^ concat ", " dims ^ "]"
-    | ArrayScalar (value, _) -> str value
+    | ArrayScalar (value, _) -> "<scalar>(" ^ str value ^ ")"
+    (*| ArrayScalar (value, _) -> str value*)
     | Var (v, _) -> v
     | Deref (name, dims, _) -> name ^ (str (ArrayConst (dims, noloc)))
     | Monop (op, opnd, _) -> op2str op ^ str opnd
@@ -110,7 +111,7 @@ and node2str node =
     | FunCall (name, args, _) -> name ^ "(" ^ (concat ", " args) ^ ")"
     | Cond (cond, t, f, _) -> (str cond) ^ " ? " ^ str t ^ " : " ^ str f
 
-    (* FIXME: these should be printed when verbose=3
+    (* FIXME: these shoud be printen when verbose > 2
     | Arg node -> "<arg>(" ^ str node ^ ")"
     | Type (node, ctype) -> str node ^ ":" ^ type2str ctype
     | VarUse (value, _, _)

+ 3 - 1
util.ml

@@ -77,6 +77,8 @@ let transform_children trav node =
     | Deref (name, dims, loc) ->
         Deref (name, trav_all dims, loc)
 
+    | ArrayScalar (value, dims) ->
+        ArrayScalar (trav value, dims)
     | Type (value, ctype) ->
         Type (trav value, ctype)
     | VarLet (assign, def, depth) ->
@@ -117,7 +119,6 @@ let rec transform_all trav = function
     | IntConst (_, loc)
     | FloatConst (_, loc)
     | ArrayConst (_, loc)
-    | ArrayScalar (_, loc)
     | Var (_, loc)
     | Deref (_, _, loc)
     | Monop (_, _, loc)
@@ -126,6 +127,7 @@ let rec transform_all trav = function
     | TypeCast (_, _, loc)
     | FunCall (_, _, loc) -> loc
 
+    | ArrayScalar (value, _)
     | Expr value
     | VarLet (value, _, _)
     | VarUse (value, _, _)