Explorar el Código

Implemented array dimension reduction, generalized Array type, more bugfixes in other phases

Taddeus Kroes hace 12 años
padre
commit
1cdc10625d
Se han modificado 11 ficheros con 73 adiciones y 49 borrados
  1. 2 3
      ast.ml
  2. 6 6
      parser.mly
  3. 1 1
      phases/context_analysis.ml
  4. 7 11
      phases/desug.ml
  5. 23 1
      phases/dim_reduce.ml
  6. 6 9
      phases/expand_dims.ml
  7. 7 12
      phases/typecheck.ml
  8. 3 4
      stringify.ml
  9. 10 0
      test/dim_reduce.cvc
  10. 5 2
      util.ml
  11. 3 0
      util.mli

+ 2 - 3
ast.ml

@@ -6,9 +6,8 @@ type operator = Neg | Not
               | Eq | Ne | Lt | Le | Gt | Ge
               | And | Or
 type ctype = Void | Bool | Int | Float
-           | ArrayDec of ctype * node list
-           | ArrayDef of ctype * node list
-           | ArraySpec of ctype * int
+           | Array of ctype * node list
+           | ArrayDepth of ctype * int
 and node =
     (* global *)
     | Program of node list * loc

+ 6 - 6
parser.mly

@@ -81,7 +81,7 @@ decl:
       name=ID; SEMICOL
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
-      GlobalDec (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      GlobalDec (Array (ctype, make_dims dimloc dims), name, loc) }
 
     | export=boption(EXPORT); ctype=basic_type; name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
@@ -95,13 +95,13 @@ decl:
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, None, loc) }
+      GlobalDef (export, Array (ctype, dims), name, None, loc) }
 
     | export=boption(EXPORT); ctype=basic_type;
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, Some init, loc) }
+      GlobalDef (export, Array (ctype, dims), name, Some init, loc) }
 
 fun_header:
     (* function header: use location of function name *)
@@ -119,7 +119,7 @@ param:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, ID); RBRACK; name=ID
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
-      Param (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      Param (Array (ctype, make_dims dimloc dims), name, loc) }
 
 fun_body:
     | var_dec* local_fun_dec* statement* loption(return_statement)
@@ -147,12 +147,12 @@ var_dec:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, None, loc) }
+      VarDec (Array (ctype, dims), name, None, loc) }
 
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, Some init, loc) }
+      VarDec (Array (ctype, dims), name, Some init, loc) }
 
 statement:
     (* assignment: use location of assigned variable name *)

+ 1 - 1
phases/context_analysis.ml

@@ -122,7 +122,7 @@ let rec analyse scope depth args node =
             let body = analyse local_scope (depth + 1) args body in
             FunDef (export, ret_type, name, params, body, loc)
 
-        | Param (ArrayDec (_, dims), name, _) as node ->
+        | Param (Array (_, dims), name, _) as node ->
             let rec add_dims = function
                 | [] -> ()
                 | Dim (name, _) as dim :: tail ->

+ 7 - 11
phases/desug.ml

@@ -32,25 +32,21 @@ let rec var_init = function
             let rec trav inits node = match node with
                 (* translate scalar array initialisation to ArrayScalar node,
                  * for easy replacement later on *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (BoolConst _  as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (FloatConst _ as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (IntConst _   as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (BoolConst _  as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (FloatConst _ as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (IntConst _   as v), loc) :: t ->
                     let init = Some (ArrayInit (ArrayScalar v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
                 (* Wrap ArrayConst in ArrayInit to pass dimensions *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (ArrayConst _ as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (ArrayConst _ as v), loc) :: t ->
                     let init = Some (ArrayInit (v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
                 | VarDec (ctype, name, init, loc) as dec :: tl ->
                     (* array definition: create __allocate statement *)
                     let alloc = match ctype with
-                        | ArrayDef (_, dims) -> [Allocate (name, dims, dec, loc)]
+                        | Array (_, dims) -> [Allocate (name, dims, dec, loc)]
                         | _ -> []
                     in
                     (* initialisation: create assign statement *)
@@ -126,7 +122,7 @@ let for_to_while node =
 
 let rec array_init = function
     (* Transform scalar assignment into nested for-loops *)
-    | Assign (name, None, ArrayInit (ArrayScalar value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayScalar value, Array (_, dims)), loc) ->
         let rec add_loop indices = function
             | [] ->
                 Assign (name, Some indices, value, loc)
@@ -140,7 +136,7 @@ let rec array_init = function
     (* Transform array constant inisialisation into separate assign statements
      * for all entries in the constant array *)
     (* TODO: only allow when array dimensions are constant? *)
-    | Assign (name, None, ArrayInit (ArrayConst _ as value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayConst _ as value, Array (_, dims)), loc) ->
         let ndims = list_size dims in
         let rec make_assigns depth i indices = function
             | [] -> []

+ 23 - 1
phases/dim_reduce.ml

@@ -1,7 +1,29 @@
 open Ast
 open Util
 
-let rec dim_reduce = function
+let rec multiply = function
+    | []       -> raise InvalidNode
+    | [node]   -> node
+    | hd :: tl -> Binop (Mul, hd, multiply tl, noloc)
+
+let rec expand dims = function
+    | []       -> raise InvalidNode
+    | [node]   -> dim_reduce node
+    | hd :: tl -> let mul = Binop (Mul, dim_reduce hd, (List.hd dims), noloc) in
+                  Binop (Mul, mul, expand (List.tl dims) tl, noloc)
+
+and dim_reduce = function
+    | Allocate (name, dims, dec, loc) ->
+        Allocate (name, [multiply dims], dec, loc)
+
+    | VarUse (Type (Deref (name, values, loc), _), (Array (_, dims) as ctype), depth) ->
+        let reduced = [expand (List.rev dims) values] in
+        VarUse (Deref (name, reduced, loc), ctype, depth)
+
+    | VarLet (Assign (name, Some values, value, loc), (Array (_, dims) as ctype), depth) ->
+        let reduced = Some [expand (List.rev dims) values] in
+        VarLet (Assign (name, reduced, dim_reduce value, loc), ctype, depth)
+
     | node -> transform_children dim_reduce node
 
 let rec phase input =

+ 6 - 9
phases/expand_dims.ml

@@ -18,10 +18,10 @@ let rec expand_dims = function
         FunCall (name, flatten_blocks (List.map expand_dims args), loc)
 
     (* Add additional parameters for array dimensions *)
-    | Param (ArrayDec (ctype, dims), name, loc) ->
+    | Param (Array (_,dims) as ctype, name, loc) ->
         let rec do_expand = function
             | [] ->
-                [Param (ArraySpec (ctype, list_size dims), name, loc)]
+                [Param (ctype, name, loc)]
             | Dim (name, loc) :: tail ->
                 Param (Int, name, loc) :: (do_expand tail)
             | _ -> raise InvalidNode
@@ -29,14 +29,11 @@ let rec expand_dims = function
         Block (do_expand dims)
 
     (* Add additional function arguments for array dimensions *)
-    | Arg (VarUse (var, ArrayDec (ctype, dims), depth)) ->
+    | Arg (VarUse (var, (Array (_, dims) as ctype), depth)) ->
         let rec do_expand = function
-            | [] ->
-                let spec = ArraySpec (ctype, list_size dims) in
-                [Arg (VarUse (var, spec, depth))]
-            | Dim (name, _) :: tl ->
-                Arg (VarUse (Var (name, noloc), Int, depth)) :: (do_expand tl)
-            | _ -> raise InvalidNode
+            | []       -> [Arg (VarUse (var, ctype, depth))]
+            | hd :: tl -> Arg (VarUse (hd, Int, depth)) :: (do_expand tl)
+            | _        -> raise InvalidNode
         in
         Block (do_expand dims)
 

+ 7 - 12
phases/typecheck.ml

@@ -21,14 +21,8 @@ open Stringify
  *)
 
 let spec = function
-    | ArrayDec (ctype, dims)
-    | ArrayDef (ctype, dims) -> ArraySpec (ctype, list_size dims)
-    | ctype                  -> ctype
-
-let array_width = function
-    | ArrayDec (_, dims)
-    | ArrayDef (_, dims) -> list_size dims
-    | _                  -> raise InvalidNode
+    | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
+    | ctype               -> ctype
 
 let check_type ?(msg="") expected = function
     | Type (node, got) when (spec got) <> (spec expected) ->
@@ -64,7 +58,7 @@ let check_type_op allowed_types desc = function
     | _ -> raise InvalidNode
 
 let check_dims_match dims dec_type errnode =
-    match (list_size dims, array_width dec_type) with
+    match (list_size dims, array_depth dec_type) with
     | (got, expected) when got != expected ->
         let msg = sprintf
             "dimension mismatch: expected %d indices, got %d" expected got
@@ -139,8 +133,10 @@ let rec typecheck node = match node with
 
         check_dims_match dims dec_type deref;
 
-        let ctype = base_type dec_type in
-        typecheck (VarUse (Type (deref, ctype), ctype, depth))
+        typecheck (VarUse (Type (deref, base_type dec_type), dec_type, depth))
+    | VarUse (Type (_, ctype), _, _)
+    | VarUse (_, ctype, _) ->
+        Type (node, ctype)
 
     | Allocate (name, dims, dec, loc) ->
         let dims = List.map typecheck dims in
@@ -196,7 +192,6 @@ let rec typecheck node = match node with
     | BoolConst  (value, _) -> Type (node, Bool)
     | IntConst   (value, _) -> Type (node, Int)
     | FloatConst (value, _) -> Type (node, Float)
-    | VarUse (_, ctype, _)  -> Type (node, ctype)
 
     | _ -> transform_children typecheck node
 

+ 3 - 4
stringify.ml

@@ -29,9 +29,8 @@ let rec type2str = function
     | Bool  -> "bool"
     | Int   -> "int"
     | Float -> "float"
-    | ArrayDec (t, dims)
-    | ArrayDef (t, dims)   -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
-    | ArraySpec (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
+    | Array (t, dims)       -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
+    | ArrayDepth (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
 
 and concat sep nodes = String.concat sep (List.map node2str nodes)
 
@@ -114,7 +113,7 @@ and node2str node =
     (* FIXME: these shoud be printen when verbose > 2
     | Arg node -> "<arg>(" ^ str node ^ ")"
     | Type (node, ctype) -> str node ^ ":" ^ type2str ctype
-    | VarUse (value, _, _)
+    | VarUse (value, ctype, _) -> "<use:" ^ type2str ctype ^ ">(" ^ str value ^ ")"
     | FunUse (value, _, _) -> "<use>(" ^ str value ^ ")"
     *)
 

+ 10 - 0
test/dim_reduce.cvc

@@ -0,0 +1,10 @@
+void foo(int[m, n] a) {
+    int[n, m] b;
+
+    for (int i = 0, n) {
+        for (int j = 0, m)
+            b[i, j] = a[j, i];
+    }
+
+    foo(b);
+}

+ 5 - 2
util.ml

@@ -207,6 +207,9 @@ let rec list_size = function
     | hd :: tl -> 1 + (list_size tl)
 
 let base_type = function
-    | ArrayDec (ctype, _)
-    | ArrayDef (ctype, _)
+    | Array (ctype, _)
     | ctype -> ctype
+
+let array_depth = function
+    | Array (_, dims) -> list_size dims
+    | _               -> raise InvalidNode

+ 3 - 0
util.mli

@@ -35,3 +35,6 @@ val list_size : 'a list -> int
 
 (* Get the basic type of a ctype, removing array dimensions *)
 val base_type : Ast.ctype -> Ast.ctype
+
+(* Get the number of dimensions from an Array type *)
+val array_depth : Ast.ctype -> int