Parcourir la source

Implemented array dimension reduction, generalized Array type, more bugfixes in other phases

Taddeus Kroes il y a 12 ans
Parent
commit
1cdc10625d
11 fichiers modifiés avec 73 ajouts et 49 suppressions
  1. 2 3
      ast.ml
  2. 6 6
      parser.mly
  3. 1 1
      phases/context_analysis.ml
  4. 7 11
      phases/desug.ml
  5. 23 1
      phases/dim_reduce.ml
  6. 6 9
      phases/expand_dims.ml
  7. 7 12
      phases/typecheck.ml
  8. 3 4
      stringify.ml
  9. 10 0
      test/dim_reduce.cvc
  10. 5 2
      util.ml
  11. 3 0
      util.mli

+ 2 - 3
ast.ml

@@ -6,9 +6,8 @@ type operator = Neg | Not
               | Eq | Ne | Lt | Le | Gt | Ge
               | And | Or
 type ctype = Void | Bool | Int | Float
-           | ArrayDec of ctype * node list
-           | ArrayDef of ctype * node list
-           | ArraySpec of ctype * int
+           | Array of ctype * node list
+           | ArrayDepth of ctype * int
 and node =
     (* global *)
     | Program of node list * loc

+ 6 - 6
parser.mly

@@ -81,7 +81,7 @@ decl:
       name=ID; SEMICOL
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
-      GlobalDec (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      GlobalDec (Array (ctype, make_dims dimloc dims), name, loc) }
 
     | export=boption(EXPORT); ctype=basic_type; name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
@@ -95,13 +95,13 @@ decl:
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, None, loc) }
+      GlobalDef (export, Array (ctype, dims), name, None, loc) }
 
     | export=boption(EXPORT); ctype=basic_type;
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, Some init, loc) }
+      GlobalDef (export, Array (ctype, dims), name, Some init, loc) }
 
 fun_header:
     (* function header: use location of function name *)
@@ -119,7 +119,7 @@ param:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, ID); RBRACK; name=ID
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
-      Param (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      Param (Array (ctype, make_dims dimloc dims), name, loc) }
 
 fun_body:
     | var_dec* local_fun_dec* statement* loption(return_statement)
@@ -147,12 +147,12 @@ var_dec:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, None, loc) }
+      VarDec (Array (ctype, dims), name, None, loc) }
 
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, Some init, loc) }
+      VarDec (Array (ctype, dims), name, Some init, loc) }
 
 statement:
     (* assignment: use location of assigned variable name *)

+ 1 - 1
phases/context_analysis.ml

@@ -122,7 +122,7 @@ let rec analyse scope depth args node =
             let body = analyse local_scope (depth + 1) args body in
             FunDef (export, ret_type, name, params, body, loc)
 
-        | Param (ArrayDec (_, dims), name, _) as node ->
+        | Param (Array (_, dims), name, _) as node ->
             let rec add_dims = function
                 | [] -> ()
                 | Dim (name, _) as dim :: tail ->

+ 7 - 11
phases/desug.ml

@@ -32,25 +32,21 @@ let rec var_init = function
             let rec trav inits node = match node with
                 (* translate scalar array initialisation to ArrayScalar node,
                  * for easy replacement later on *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (BoolConst _  as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (FloatConst _ as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (IntConst _   as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (BoolConst _  as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (FloatConst _ as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (IntConst _   as v), loc) :: t ->
                     let init = Some (ArrayInit (ArrayScalar v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
                 (* Wrap ArrayConst in ArrayInit to pass dimensions *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (ArrayConst _ as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (ArrayConst _ as v), loc) :: t ->
                     let init = Some (ArrayInit (v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
                 | VarDec (ctype, name, init, loc) as dec :: tl ->
                     (* array definition: create __allocate statement *)
                     let alloc = match ctype with
-                        | ArrayDef (_, dims) -> [Allocate (name, dims, dec, loc)]
+                        | Array (_, dims) -> [Allocate (name, dims, dec, loc)]
                         | _ -> []
                     in
                     (* initialisation: create assign statement *)
@@ -126,7 +122,7 @@ let for_to_while node =
 
 let rec array_init = function
     (* Transform scalar assignment into nested for-loops *)
-    | Assign (name, None, ArrayInit (ArrayScalar value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayScalar value, Array (_, dims)), loc) ->
         let rec add_loop indices = function
             | [] ->
                 Assign (name, Some indices, value, loc)
@@ -140,7 +136,7 @@ let rec array_init = function
     (* Transform array constant inisialisation into separate assign statements
      * for all entries in the constant array *)
     (* TODO: only allow when array dimensions are constant? *)
-    | Assign (name, None, ArrayInit (ArrayConst _ as value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayConst _ as value, Array (_, dims)), loc) ->
         let ndims = list_size dims in
         let rec make_assigns depth i indices = function
             | [] -> []

+ 23 - 1
phases/dim_reduce.ml

@@ -1,7 +1,29 @@
 open Ast
 open Util
 
-let rec dim_reduce = function
+let rec multiply = function
+    | []       -> raise InvalidNode
+    | [node]   -> node
+    | hd :: tl -> Binop (Mul, hd, multiply tl, noloc)
+
+let rec expand dims = function
+    | []       -> raise InvalidNode
+    | [node]   -> dim_reduce node
+    | hd :: tl -> let mul = Binop (Mul, dim_reduce hd, (List.hd dims), noloc) in
+                  Binop (Mul, mul, expand (List.tl dims) tl, noloc)
+
+and dim_reduce = function
+    | Allocate (name, dims, dec, loc) ->
+        Allocate (name, [multiply dims], dec, loc)
+
+    | VarUse (Type (Deref (name, values, loc), _), (Array (_, dims) as ctype), depth) ->
+        let reduced = [expand (List.rev dims) values] in
+        VarUse (Deref (name, reduced, loc), ctype, depth)
+
+    | VarLet (Assign (name, Some values, value, loc), (Array (_, dims) as ctype), depth) ->
+        let reduced = Some [expand (List.rev dims) values] in
+        VarLet (Assign (name, reduced, dim_reduce value, loc), ctype, depth)
+
     | node -> transform_children dim_reduce node
 
 let rec phase input =

+ 6 - 9
phases/expand_dims.ml

@@ -18,10 +18,10 @@ let rec expand_dims = function
         FunCall (name, flatten_blocks (List.map expand_dims args), loc)
 
     (* Add additional parameters for array dimensions *)
-    | Param (ArrayDec (ctype, dims), name, loc) ->
+    | Param (Array (_,dims) as ctype, name, loc) ->
         let rec do_expand = function
             | [] ->
-                [Param (ArraySpec (ctype, list_size dims), name, loc)]
+                [Param (ctype, name, loc)]
             | Dim (name, loc) :: tail ->
                 Param (Int, name, loc) :: (do_expand tail)
             | _ -> raise InvalidNode
@@ -29,14 +29,11 @@ let rec expand_dims = function
         Block (do_expand dims)
 
     (* Add additional function arguments for array dimensions *)
-    | Arg (VarUse (var, ArrayDec (ctype, dims), depth)) ->
+    | Arg (VarUse (var, (Array (_, dims) as ctype), depth)) ->
         let rec do_expand = function
-            | [] ->
-                let spec = ArraySpec (ctype, list_size dims) in
-                [Arg (VarUse (var, spec, depth))]
-            | Dim (name, _) :: tl ->
-                Arg (VarUse (Var (name, noloc), Int, depth)) :: (do_expand tl)
-            | _ -> raise InvalidNode
+            | []       -> [Arg (VarUse (var, ctype, depth))]
+            | hd :: tl -> Arg (VarUse (hd, Int, depth)) :: (do_expand tl)
+            | _        -> raise InvalidNode
         in
         Block (do_expand dims)
 

+ 7 - 12
phases/typecheck.ml

@@ -21,14 +21,8 @@ open Stringify
  *)
 
 let spec = function
-    | ArrayDec (ctype, dims)
-    | ArrayDef (ctype, dims) -> ArraySpec (ctype, list_size dims)
-    | ctype                  -> ctype
-
-let array_width = function
-    | ArrayDec (_, dims)
-    | ArrayDef (_, dims) -> list_size dims
-    | _                  -> raise InvalidNode
+    | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
+    | ctype               -> ctype
 
 let check_type ?(msg="") expected = function
     | Type (node, got) when (spec got) <> (spec expected) ->
@@ -64,7 +58,7 @@ let check_type_op allowed_types desc = function
     | _ -> raise InvalidNode
 
 let check_dims_match dims dec_type errnode =
-    match (list_size dims, array_width dec_type) with
+    match (list_size dims, array_depth dec_type) with
     | (got, expected) when got != expected ->
         let msg = sprintf
             "dimension mismatch: expected %d indices, got %d" expected got
@@ -139,8 +133,10 @@ let rec typecheck node = match node with
 
         check_dims_match dims dec_type deref;
 
-        let ctype = base_type dec_type in
-        typecheck (VarUse (Type (deref, ctype), ctype, depth))
+        typecheck (VarUse (Type (deref, base_type dec_type), dec_type, depth))
+    | VarUse (Type (_, ctype), _, _)
+    | VarUse (_, ctype, _) ->
+        Type (node, ctype)
 
     | Allocate (name, dims, dec, loc) ->
         let dims = List.map typecheck dims in
@@ -196,7 +192,6 @@ let rec typecheck node = match node with
     | BoolConst  (value, _) -> Type (node, Bool)
     | IntConst   (value, _) -> Type (node, Int)
     | FloatConst (value, _) -> Type (node, Float)
-    | VarUse (_, ctype, _)  -> Type (node, ctype)
 
     | _ -> transform_children typecheck node
 

+ 3 - 4
stringify.ml

@@ -29,9 +29,8 @@ let rec type2str = function
     | Bool  -> "bool"
     | Int   -> "int"
     | Float -> "float"
-    | ArrayDec (t, dims)
-    | ArrayDef (t, dims)   -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
-    | ArraySpec (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
+    | Array (t, dims)       -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
+    | ArrayDepth (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
 
 and concat sep nodes = String.concat sep (List.map node2str nodes)
 
@@ -114,7 +113,7 @@ and node2str node =
     (* FIXME: these shoud be printen when verbose > 2
     | Arg node -> "<arg>(" ^ str node ^ ")"
     | Type (node, ctype) -> str node ^ ":" ^ type2str ctype
-    | VarUse (value, _, _)
+    | VarUse (value, ctype, _) -> "<use:" ^ type2str ctype ^ ">(" ^ str value ^ ")"
     | FunUse (value, _, _) -> "<use>(" ^ str value ^ ")"
     *)
 

+ 10 - 0
test/dim_reduce.cvc

@@ -0,0 +1,10 @@
+void foo(int[m, n] a) {
+    int[n, m] b;
+
+    for (int i = 0, n) {
+        for (int j = 0, m)
+            b[i, j] = a[j, i];
+    }
+
+    foo(b);
+}

+ 5 - 2
util.ml

@@ -207,6 +207,9 @@ let rec list_size = function
     | hd :: tl -> 1 + (list_size tl)
 
 let base_type = function
-    | ArrayDec (ctype, _)
-    | ArrayDef (ctype, _)
+    | Array (ctype, _)
     | ctype -> ctype
+
+let array_depth = function
+    | Array (_, dims) -> list_size dims
+    | _               -> raise InvalidNode

+ 3 - 0
util.mli

@@ -35,3 +35,6 @@ val list_size : 'a list -> int
 
 (* Get the basic type of a ctype, removing array dimensions *)
 val base_type : Ast.ctype -> Ast.ctype
+
+(* Get the number of dimensions from an Array type *)
+val array_depth : Ast.ctype -> int