Browse Source

Implemented array dimension reduction, generalized Array type, more bugfixes in other phases

Taddeus Kroes 12 years ago
parent
commit
1cdc10625d
11 changed files with 73 additions and 49 deletions
  1. 2 3
      ast.ml
  2. 6 6
      parser.mly
  3. 1 1
      phases/context_analysis.ml
  4. 7 11
      phases/desug.ml
  5. 23 1
      phases/dim_reduce.ml
  6. 6 9
      phases/expand_dims.ml
  7. 7 12
      phases/typecheck.ml
  8. 3 4
      stringify.ml
  9. 10 0
      test/dim_reduce.cvc
  10. 5 2
      util.ml
  11. 3 0
      util.mli

+ 2 - 3
ast.ml

@@ -6,9 +6,8 @@ type operator = Neg | Not
               | Eq | Ne | Lt | Le | Gt | Ge
               | Eq | Ne | Lt | Le | Gt | Ge
               | And | Or
               | And | Or
 type ctype = Void | Bool | Int | Float
 type ctype = Void | Bool | Int | Float
-           | ArrayDec of ctype * node list
-           | ArrayDef of ctype * node list
-           | ArraySpec of ctype * int
+           | Array of ctype * node list
+           | ArrayDepth of ctype * int
 and node =
 and node =
     (* global *)
     (* global *)
     | Program of node list * loc
     | Program of node list * loc

+ 6 - 6
parser.mly

@@ -81,7 +81,7 @@ decl:
       name=ID; SEMICOL
       name=ID; SEMICOL
     { let dimloc = loc $startpos(dims) $endpos(dims) in
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
       let loc = loc $startpos(name) $endpos(name) in
-      GlobalDec (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      GlobalDec (Array (ctype, make_dims dimloc dims), name, loc) }
 
 
     | export=boption(EXPORT); ctype=basic_type; name=ID; SEMICOL
     | export=boption(EXPORT); ctype=basic_type; name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
     { let loc = loc $startpos(name) $endpos(name) in
@@ -95,13 +95,13 @@ decl:
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, None, loc) }
+      GlobalDef (export, Array (ctype, dims), name, None, loc) }
 
 
     | export=boption(EXPORT); ctype=basic_type;
     | export=boption(EXPORT); ctype=basic_type;
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
     { let loc = loc $startpos(name) $endpos(name) in
-      GlobalDef (export, ArrayDef (ctype, dims), name, Some init, loc) }
+      GlobalDef (export, Array (ctype, dims), name, Some init, loc) }
 
 
 fun_header:
 fun_header:
     (* function header: use location of function name *)
     (* function header: use location of function name *)
@@ -119,7 +119,7 @@ param:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, ID); RBRACK; name=ID
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, ID); RBRACK; name=ID
     { let dimloc = loc $startpos(dims) $endpos(dims) in
     { let dimloc = loc $startpos(dims) $endpos(dims) in
       let loc = loc $startpos(name) $endpos(name) in
       let loc = loc $startpos(name) $endpos(name) in
-      Param (ArrayDec (ctype, make_dims dimloc dims), name, loc) }
+      Param (Array (ctype, make_dims dimloc dims), name, loc) }
 
 
 fun_body:
 fun_body:
     | var_dec* local_fun_dec* statement* loption(return_statement)
     | var_dec* local_fun_dec* statement* loption(return_statement)
@@ -147,12 +147,12 @@ var_dec:
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; SEMICOL
       name=ID; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, None, loc) }
+      VarDec (Array (ctype, dims), name, None, loc) }
 
 
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
     | ctype=basic_type; LBRACK; dims=separated_list(COMMA, expr); RBRACK;
       name=ID; ASSIGN; init=expr; SEMICOL
       name=ID; ASSIGN; init=expr; SEMICOL
     { let loc = loc $startpos(name) $endpos(name) in
     { let loc = loc $startpos(name) $endpos(name) in
-      VarDec (ArrayDef (ctype, dims), name, Some init, loc) }
+      VarDec (Array (ctype, dims), name, Some init, loc) }
 
 
 statement:
 statement:
     (* assignment: use location of assigned variable name *)
     (* assignment: use location of assigned variable name *)

+ 1 - 1
phases/context_analysis.ml

@@ -122,7 +122,7 @@ let rec analyse scope depth args node =
             let body = analyse local_scope (depth + 1) args body in
             let body = analyse local_scope (depth + 1) args body in
             FunDef (export, ret_type, name, params, body, loc)
             FunDef (export, ret_type, name, params, body, loc)
 
 
-        | Param (ArrayDec (_, dims), name, _) as node ->
+        | Param (Array (_, dims), name, _) as node ->
             let rec add_dims = function
             let rec add_dims = function
                 | [] -> ()
                 | [] -> ()
                 | Dim (name, _) as dim :: tail ->
                 | Dim (name, _) as dim :: tail ->

+ 7 - 11
phases/desug.ml

@@ -32,25 +32,21 @@ let rec var_init = function
             let rec trav inits node = match node with
             let rec trav inits node = match node with
                 (* translate scalar array initialisation to ArrayScalar node,
                 (* translate scalar array initialisation to ArrayScalar node,
                  * for easy replacement later on *)
                  * for easy replacement later on *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (BoolConst _  as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (FloatConst _ as v), loc) :: t
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (IntConst _   as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (BoolConst _  as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (FloatConst _ as v), loc) :: t
+                | VarDec (Array _ as vtype, name, Some (IntConst _   as v), loc) :: t ->
                     let init = Some (ArrayInit (ArrayScalar v, vtype)) in
                     let init = Some (ArrayInit (ArrayScalar v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
 
                 (* Wrap ArrayConst in ArrayInit to pass dimensions *)
                 (* Wrap ArrayConst in ArrayInit to pass dimensions *)
-                | VarDec (ArrayDef _ as vtype, name,
-                          Some (ArrayConst _ as v), loc) :: t ->
+                | VarDec (Array _ as vtype, name, Some (ArrayConst _ as v), loc) :: t ->
                     let init = Some (ArrayInit (v, vtype)) in
                     let init = Some (ArrayInit (v, vtype)) in
                     trav inits (VarDec (vtype, name, init, loc) :: t)
                     trav inits (VarDec (vtype, name, init, loc) :: t)
 
 
                 | VarDec (ctype, name, init, loc) as dec :: tl ->
                 | VarDec (ctype, name, init, loc) as dec :: tl ->
                     (* array definition: create __allocate statement *)
                     (* array definition: create __allocate statement *)
                     let alloc = match ctype with
                     let alloc = match ctype with
-                        | ArrayDef (_, dims) -> [Allocate (name, dims, dec, loc)]
+                        | Array (_, dims) -> [Allocate (name, dims, dec, loc)]
                         | _ -> []
                         | _ -> []
                     in
                     in
                     (* initialisation: create assign statement *)
                     (* initialisation: create assign statement *)
@@ -126,7 +122,7 @@ let for_to_while node =
 
 
 let rec array_init = function
 let rec array_init = function
     (* Transform scalar assignment into nested for-loops *)
     (* Transform scalar assignment into nested for-loops *)
-    | Assign (name, None, ArrayInit (ArrayScalar value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayScalar value, Array (_, dims)), loc) ->
         let rec add_loop indices = function
         let rec add_loop indices = function
             | [] ->
             | [] ->
                 Assign (name, Some indices, value, loc)
                 Assign (name, Some indices, value, loc)
@@ -140,7 +136,7 @@ let rec array_init = function
     (* Transform array constant inisialisation into separate assign statements
     (* Transform array constant inisialisation into separate assign statements
      * for all entries in the constant array *)
      * for all entries in the constant array *)
     (* TODO: only allow when array dimensions are constant? *)
     (* TODO: only allow when array dimensions are constant? *)
-    | Assign (name, None, ArrayInit (ArrayConst _ as value, ArrayDef (_, dims)), loc) ->
+    | Assign (name, None, ArrayInit (ArrayConst _ as value, Array (_, dims)), loc) ->
         let ndims = list_size dims in
         let ndims = list_size dims in
         let rec make_assigns depth i indices = function
         let rec make_assigns depth i indices = function
             | [] -> []
             | [] -> []

+ 23 - 1
phases/dim_reduce.ml

@@ -1,7 +1,29 @@
 open Ast
 open Ast
 open Util
 open Util
 
 
-let rec dim_reduce = function
+let rec multiply = function
+    | []       -> raise InvalidNode
+    | [node]   -> node
+    | hd :: tl -> Binop (Mul, hd, multiply tl, noloc)
+
+let rec expand dims = function
+    | []       -> raise InvalidNode
+    | [node]   -> dim_reduce node
+    | hd :: tl -> let mul = Binop (Mul, dim_reduce hd, (List.hd dims), noloc) in
+                  Binop (Mul, mul, expand (List.tl dims) tl, noloc)
+
+and dim_reduce = function
+    | Allocate (name, dims, dec, loc) ->
+        Allocate (name, [multiply dims], dec, loc)
+
+    | VarUse (Type (Deref (name, values, loc), _), (Array (_, dims) as ctype), depth) ->
+        let reduced = [expand (List.rev dims) values] in
+        VarUse (Deref (name, reduced, loc), ctype, depth)
+
+    | VarLet (Assign (name, Some values, value, loc), (Array (_, dims) as ctype), depth) ->
+        let reduced = Some [expand (List.rev dims) values] in
+        VarLet (Assign (name, reduced, dim_reduce value, loc), ctype, depth)
+
     | node -> transform_children dim_reduce node
     | node -> transform_children dim_reduce node
 
 
 let rec phase input =
 let rec phase input =

+ 6 - 9
phases/expand_dims.ml

@@ -18,10 +18,10 @@ let rec expand_dims = function
         FunCall (name, flatten_blocks (List.map expand_dims args), loc)
         FunCall (name, flatten_blocks (List.map expand_dims args), loc)
 
 
     (* Add additional parameters for array dimensions *)
     (* Add additional parameters for array dimensions *)
-    | Param (ArrayDec (ctype, dims), name, loc) ->
+    | Param (Array (_,dims) as ctype, name, loc) ->
         let rec do_expand = function
         let rec do_expand = function
             | [] ->
             | [] ->
-                [Param (ArraySpec (ctype, list_size dims), name, loc)]
+                [Param (ctype, name, loc)]
             | Dim (name, loc) :: tail ->
             | Dim (name, loc) :: tail ->
                 Param (Int, name, loc) :: (do_expand tail)
                 Param (Int, name, loc) :: (do_expand tail)
             | _ -> raise InvalidNode
             | _ -> raise InvalidNode
@@ -29,14 +29,11 @@ let rec expand_dims = function
         Block (do_expand dims)
         Block (do_expand dims)
 
 
     (* Add additional function arguments for array dimensions *)
     (* Add additional function arguments for array dimensions *)
-    | Arg (VarUse (var, ArrayDec (ctype, dims), depth)) ->
+    | Arg (VarUse (var, (Array (_, dims) as ctype), depth)) ->
         let rec do_expand = function
         let rec do_expand = function
-            | [] ->
-                let spec = ArraySpec (ctype, list_size dims) in
-                [Arg (VarUse (var, spec, depth))]
-            | Dim (name, _) :: tl ->
-                Arg (VarUse (Var (name, noloc), Int, depth)) :: (do_expand tl)
-            | _ -> raise InvalidNode
+            | []       -> [Arg (VarUse (var, ctype, depth))]
+            | hd :: tl -> Arg (VarUse (hd, Int, depth)) :: (do_expand tl)
+            | _        -> raise InvalidNode
         in
         in
         Block (do_expand dims)
         Block (do_expand dims)
 
 

+ 7 - 12
phases/typecheck.ml

@@ -21,14 +21,8 @@ open Stringify
  *)
  *)
 
 
 let spec = function
 let spec = function
-    | ArrayDec (ctype, dims)
-    | ArrayDef (ctype, dims) -> ArraySpec (ctype, list_size dims)
-    | ctype                  -> ctype
-
-let array_width = function
-    | ArrayDec (_, dims)
-    | ArrayDef (_, dims) -> list_size dims
-    | _                  -> raise InvalidNode
+    | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
+    | ctype               -> ctype
 
 
 let check_type ?(msg="") expected = function
 let check_type ?(msg="") expected = function
     | Type (node, got) when (spec got) <> (spec expected) ->
     | Type (node, got) when (spec got) <> (spec expected) ->
@@ -64,7 +58,7 @@ let check_type_op allowed_types desc = function
     | _ -> raise InvalidNode
     | _ -> raise InvalidNode
 
 
 let check_dims_match dims dec_type errnode =
 let check_dims_match dims dec_type errnode =
-    match (list_size dims, array_width dec_type) with
+    match (list_size dims, array_depth dec_type) with
     | (got, expected) when got != expected ->
     | (got, expected) when got != expected ->
         let msg = sprintf
         let msg = sprintf
             "dimension mismatch: expected %d indices, got %d" expected got
             "dimension mismatch: expected %d indices, got %d" expected got
@@ -139,8 +133,10 @@ let rec typecheck node = match node with
 
 
         check_dims_match dims dec_type deref;
         check_dims_match dims dec_type deref;
 
 
-        let ctype = base_type dec_type in
-        typecheck (VarUse (Type (deref, ctype), ctype, depth))
+        typecheck (VarUse (Type (deref, base_type dec_type), dec_type, depth))
+    | VarUse (Type (_, ctype), _, _)
+    | VarUse (_, ctype, _) ->
+        Type (node, ctype)
 
 
     | Allocate (name, dims, dec, loc) ->
     | Allocate (name, dims, dec, loc) ->
         let dims = List.map typecheck dims in
         let dims = List.map typecheck dims in
@@ -196,7 +192,6 @@ let rec typecheck node = match node with
     | BoolConst  (value, _) -> Type (node, Bool)
     | BoolConst  (value, _) -> Type (node, Bool)
     | IntConst   (value, _) -> Type (node, Int)
     | IntConst   (value, _) -> Type (node, Int)
     | FloatConst (value, _) -> Type (node, Float)
     | FloatConst (value, _) -> Type (node, Float)
-    | VarUse (_, ctype, _)  -> Type (node, ctype)
 
 
     | _ -> transform_children typecheck node
     | _ -> transform_children typecheck node
 
 

+ 3 - 4
stringify.ml

@@ -29,9 +29,8 @@ let rec type2str = function
     | Bool  -> "bool"
     | Bool  -> "bool"
     | Int   -> "int"
     | Int   -> "int"
     | Float -> "float"
     | Float -> "float"
-    | ArrayDec (t, dims)
-    | ArrayDef (t, dims)   -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
-    | ArraySpec (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
+    | Array (t, dims)       -> (type2str t) ^ "[" ^ (concat ", " dims) ^ "]"
+    | ArrayDepth (t, ndims) -> (type2str t) ^ "[" ^ string_of_int ndims ^ "]"
 
 
 and concat sep nodes = String.concat sep (List.map node2str nodes)
 and concat sep nodes = String.concat sep (List.map node2str nodes)
 
 
@@ -114,7 +113,7 @@ and node2str node =
     (* FIXME: these shoud be printen when verbose > 2
     (* FIXME: these shoud be printen when verbose > 2
     | Arg node -> "<arg>(" ^ str node ^ ")"
     | Arg node -> "<arg>(" ^ str node ^ ")"
     | Type (node, ctype) -> str node ^ ":" ^ type2str ctype
     | Type (node, ctype) -> str node ^ ":" ^ type2str ctype
-    | VarUse (value, _, _)
+    | VarUse (value, ctype, _) -> "<use:" ^ type2str ctype ^ ">(" ^ str value ^ ")"
     | FunUse (value, _, _) -> "<use>(" ^ str value ^ ")"
     | FunUse (value, _, _) -> "<use>(" ^ str value ^ ")"
     *)
     *)
 
 

+ 10 - 0
test/dim_reduce.cvc

@@ -0,0 +1,10 @@
+void foo(int[m, n] a) {
+    int[n, m] b;
+
+    for (int i = 0, n) {
+        for (int j = 0, m)
+            b[i, j] = a[j, i];
+    }
+
+    foo(b);
+}

+ 5 - 2
util.ml

@@ -207,6 +207,9 @@ let rec list_size = function
     | hd :: tl -> 1 + (list_size tl)
     | hd :: tl -> 1 + (list_size tl)
 
 
 let base_type = function
 let base_type = function
-    | ArrayDec (ctype, _)
-    | ArrayDef (ctype, _)
+    | Array (ctype, _)
     | ctype -> ctype
     | ctype -> ctype
+
+let array_depth = function
+    | Array (_, dims) -> list_size dims
+    | _               -> raise InvalidNode

+ 3 - 0
util.mli

@@ -35,3 +35,6 @@ val list_size : 'a list -> int
 
 
 (* Get the basic type of a ctype, removing array dimensions *)
 (* Get the basic type of a ctype, removing array dimensions *)
 val base_type : Ast.ctype -> Ast.ctype
 val base_type : Ast.ctype -> Ast.ctype
+
+(* Get the number of dimensions from an Array type *)
+val array_depth : Ast.ctype -> int