| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- open Types
- open Util
(* Strip the dimension list from the array type of a declaration node.
 *
 * Accepts a GlobalDef or VarDec whose fourth field is None (presumably: no
 * initialiser -- confirm against Types), or a Param. In each case the
 * ArrayDims type is replaced by a plain Array of the same element type and
 * all other fields are copied unchanged.
 *
 * Raises InvalidNode for every other node, including declarations that do
 * not have an ArrayDims type. *)
let flatten_type decl =
    match decl with
    | GlobalDef (export, ArrayDims (elem_type, _), name, None, ann) ->
        GlobalDef (export, Array elem_type, name, None, ann)
    | VarDec (ArrayDims (elem_type, _), name, None, ann) ->
        VarDec (Array elem_type, name, None, ann)
    | Param (ArrayDims (elem_type, _), name, ann) ->
        Param (Array elem_type, name, ann)
    | _ ->
        raise InvalidNode
(* Pass array dimensions explicitly to functions.
 *
 * Array parameters and array arguments are each replaced by a Block that
 * holds one extra Int parameter/argument per dimension, followed by the
 * array itself. The function nodes that own those lists (FunDef, FunDec,
 * FunUse) splice the Blocks back into a flat list with flatten_blocks.
 * Every other node is handed to traverse_unit together with expand_dims
 * (presumably a generic recursive traversal -- see Util).
 *
 * Raises InvalidNode when a dimension list contains something other than a
 * Dim node, or when an array argument's declaration has no ArrayDims type. *)
let rec expand_dims = function
    (* Flatten Block nodes returned by the transformations below *)
    | FunDef (export, ret_type, name, params, body, ann) ->
        let params = flatten_blocks (List.map expand_dims params) in
        FunDef (export, ret_type, name, params, expand_dims body, ann)

    | FunDec (ret_type, name, params, ann) ->
        let params = flatten_blocks (List.map expand_dims params) in
        FunDec (ret_type, name, params, ann)

    | FunUse (dec, params, ann) ->
        FunUse (dec, flatten_blocks (List.map expand_dims params), ann)

    (* Add additional parameters for array dimensions *)
    | Param (ArrayDims (ctype, dims), name, ann) ->
        let rec do_expand = function
            (* The array parameter itself comes last, with a flat type *)
            | [] -> [Param (Array ctype, name, ann)]
            (* Each dimension becomes an Int parameter that reuses the
             * dimension's own name and annotations *)
            | Dim (dim_name, dim_ann) :: tail ->
                Param (Int, dim_name, dim_ann) :: (do_expand tail)
            | _ -> raise InvalidNode
        in
        Block (do_expand dims)

    (* Add additional function arguments for array dimensions *)
    | Arg (VarUse (dec, None, ann)) when is_array dec ->
        (* Rebuild the declaration that a dimension variable refers to *)
        let make_dimdec = function
            | Dim (dim_name, dim_ann) -> Param (Int, dim_name, dim_ann)
            | _ -> raise InvalidNode
        in
        let rec do_expand = function
            | [] ->
                (* Remove the (now obsolete) dimensions from the type.
                 * NOTE(review): unlike the dimension arguments below, this
                 * VarUse is not wrapped in Arg -- confirm that later phases
                 * expect a bare VarUse here. *)
                [VarUse (flatten_type dec, None, ann)]
            | hd :: tl ->
                (* A declaration has been added for each dimension during
                 * earlier phases, so we can safely reconstruct it here *)
                Arg (VarUse (make_dimdec hd, None, [])) :: (do_expand tl)
        in
        let dims = match typeof dec with
            | ArrayDims (_, dims) -> dims
            | _ -> raise InvalidNode
        in
        Block (do_expand dims)

    | node -> traverse_unit expand_dims node
(* Build the product of a non-empty list of nodes as a right-nested chain of
 * Mul Binops, each annotated with [Type Int]. A single-element list yields
 * that element unchanged.
 *
 * Raises InvalidNode on an empty list. *)
let rec multiply factors =
    match factors with
    | [] -> raise InvalidNode
    | [only] -> only
    | first :: rest -> Binop (Mul, first, multiply rest, [Type Int])
(* Build the product of a non-empty list of nodes as a right-nested chain of
 * Mul Binops that carry an empty annotation list. A single-element list
 * yields that element unchanged.
 *
 * Raises InvalidNode on an empty list. *)
let multiply_all nodes =
    (* product first rest = first * (product of rest), right-nested *)
    let rec product first = function
        | [] -> first
        | next :: rest -> Binop (Mul, first, product next rest, [])
    in
    match nodes with
    | [] -> raise InvalidNode
    | first :: rest -> product first rest
(* Collapse the index expressions of an array access into one expression.
 *
 * [dims] is the dimension list of the array's ArrayDims type; the list of
 * index nodes is taken as a third argument. The first dimension is skipped
 * and every remaining Dim is turned into a VarUse of that dimension. The
 * indices are then folded from the left: the running index is multiplied by
 * the width of the next dimension and the next index is added to it, so
 * indices [i; j; k] over dimensions [_; d1; d2] give ((i * d1) + j) * d2 + k.
 * All index nodes are passed through dim_reduce along the way.
 *
 * Raises InvalidNode when [dims] or the index list is empty, or when there
 * are more indices than dimensions. *)
let rec expand depth dims =
    (* [widths] are the dimension widths that have not been folded in yet *)
    let rec do_expand widths = function
        | [] -> raise InvalidNode
        | [node] -> dim_reduce depth node
        | i :: j :: tl ->
            begin match widths with
                | [] -> raise InvalidNode
                | parent_width :: widths ->
                    let mul =
                        Binop (Mul, dim_reduce depth i, parent_width, [Type Int])
                    in
                    do_expand widths (Binop (Add, mul, j, [Type Int]) :: tl)
            end
    in
    (* Refer to a dimension through a variable use at the current depth *)
    let use_dim = function
        | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
        | node -> node
    in
    (* The width of the outermost dimension never enters the offset *)
    match dims with
    | [] -> raise InvalidNode
    | _ :: inner -> do_expand (List.map use_dim inner)

(* Transform multi-dimensional arrays into one-dimensional arrays in row-major
 * order.
 *
 * [depth] is the function nesting depth; it starts at the caller's value and
 * is incremented on entering a FunDef. Nodes not matched below are handed to
 * traverse_unit together with (dim_reduce depth).
 *
 * Raises InvalidNode when an indexed VarUse or VarLet refers to a declaration
 * whose type is not ArrayDims. *)
and dim_reduce depth = function
    (* Allocate a single dimension whose size is the product of all sizes *)
    | Allocate (dec, dims, ann) ->
        Allocate (dec, [multiply dims], ann)

    (* Simplify array types in declarations.
     * A second, byte-identical GlobalDef case that could never match has
     * been removed from this spot. NOTE(review): check whether another
     * declaration constructor was meant to be handled there. *)
    | GlobalDef (export, ArrayDims (ctype, _), name, None, ann) ->
        GlobalDef (export, Array ctype, name, None, ann)

    | VarDec (ArrayDims (ctype, _), name, None, ann) ->
        VarDec (Array ctype, name, None, ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (export, ret_type, name, params, body, ann) ->
        let trav = dim_reduce (depth + 1) in
        FunDef (export, ret_type, name, List.map trav params, trav body, ann)

    (* Expand indices when dereferencing *)
    | VarUse (dec, Some values, ann) ->
        begin match typeof dec with
            | ArrayDims (_, dims) ->
                VarUse (dec, Some [expand depth dims values], ann)
            | _ -> raise InvalidNode
        end

    (* Expand indices when assigning to array index *)
    | VarLet (dec, Some values, value, ann) ->
        begin match typeof dec with
            | ArrayDims (_, dims) ->
                let value = dim_reduce depth value in
                VarLet (dec, Some [expand depth dims values], value, ann)
            | _ -> raise InvalidNode
        end

    | node -> traverse_unit (dim_reduce depth) node
(* Entry point of the dimension reduction phase.
 *
 * For an Ast input, first runs expand_dims over the tree and then dim_reduce
 * starting at depth 0, and wraps the result in Ast again.
 *
 * Raises InvalidInput "dimension reduction" for any other kind of input. *)
let phase input =
    match input with
    | Ast node ->
        let with_explicit_dims = expand_dims node in
        Ast (dim_reduce 0 with_explicit_dims)
    | _ ->
        raise (InvalidInput "dimension reduction")
|