(* dimreduce.ml *)
  1. open Types
  2. open Util
  (* Drop the dimension list from an array declaration: the ArrayDims type of
   * the declaration is replaced by a plain Array of the same element type, and
   * every other field is copied unchanged.
   *
   * GlobalDef and VarDec are only accepted when their optional field is None
   * (presumably the initialisation). Any other node raises InvalidNode. *)
  let flatten_type node =
    match node with
    | Param (ArrayDims (elem, _), id, ann) ->
      Param (Array elem, id, ann)
    | VarDec (ArrayDims (elem, _), id, None, ann) ->
      VarDec (Array elem, id, None, ann)
    | GlobalDec (ArrayDims (elem, _), id, ann) ->
      GlobalDec (Array elem, id, ann)
    | GlobalDef (exported, ArrayDims (elem, _), id, None, ann) ->
      GlobalDef (exported, Array elem, id, None, ann)
    | _ ->
      raise InvalidNode
  (* Pass array dimensions explicitly to functions.
   *
   * Rewrites the tree so that every array declaration, parameter or argument
   * handled below is preceded by one Int node per dimension:
   *   - GlobalDec of an array: one Int GlobalDec per dimension,
   *   - Param of an array:     one Int Param per dimension,
   *   - Arg using an array:    one Arg per dimension.
   * These cases return a Block holding the dimension nodes followed by the
   * array node itself; the function cases flatten those Blocks back into their
   * parameter/argument lists with flatten_blocks. All other nodes are
   * traversed with traverse_unit.
   *
   * Raises InvalidNode when a dimension list contains anything but Dim nodes,
   * or when an array argument's declaration does not have an ArrayDims type. *)
  let rec expand_dims = function
    (* For extern arrays, also add a new variable for each dimension with a
     * consistent naming scheme so that they can be exported with the same name
     * by another module *)
    | GlobalDec (ArrayDims (_, dims) as ctype, name, ann) ->
      let dim_dec = function
        | Dim (name, ann) -> GlobalDec (Int, name, ann)
        | _ -> raise InvalidNode
      in
      (* Dimension declarations first, then the array declaration itself *)
      Block (List.map dim_dec dims @ [GlobalDec (ctype, name, ann)])

    (* Flatten Block nodes returned by transformations below *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDef (export, ret_type, name, params, expand_dims body, ann)

    | FunDec (ret_type, name, params, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDec (ret_type, name, params, ann)

    | FunUse (dec, params, ann) ->
      FunUse (dec, flatten_blocks (List.map expand_dims params), ann)

    (* Add additional parameters for array dimensions *)
    | Param (ArrayDims (_, dims) as ctype, name, ann) ->
      let dim_param = function
        | Dim (name, ann) -> Param (Int, name, ann)
        | _ -> raise InvalidNode
      in
      Block (List.map dim_param dims @ [Param (ctype, name, ann)])

    (* Add additional function arguments for array dimensions *)
    | Arg (VarUse (dec, None, ann)) when is_array dec ->
      (* A declaration has been added for each dimension during earlier
       * phases, so we can safely reconstruct it here *)
      let dim_arg = function
        | Dim (name, ann) -> Arg (VarUse (Param (Int, name, ann), None, []))
        | _ -> raise InvalidNode
      in
      let dims =
        match typeof dec with
        | ArrayDims (_, dims) -> dims
        | _ -> raise InvalidNode
      in
      (* Remove the (now obsolete) dimensions from the type. The array itself
       * stays wrapped in Arg, like the dimension arguments in front of it, so
       * the expanded argument list consists of Arg nodes only. *)
      let array_arg = Arg (VarUse (flatten_type dec, None, ann)) in
      Block (List.map dim_arg dims @ [array_arg])

    | node -> traverse_unit expand_dims node
  (* Combine a non-empty list of expressions into a single right-nested Int
   * multiplication: [a; b; c] becomes a * (b * c). A one-element list yields
   * that element unchanged; an empty list raises InvalidNode. *)
  let multiply factors =
    match List.rev factors with
    | [] -> raise InvalidNode
    | last :: preceding ->
      (* Walking the reversed list builds the product from the innermost
       * (right-most) factor outwards *)
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, [Type Int]))
        last
        preceding
  (* Build the one-dimensional index expression for a list of array indices.
   *
   * [expand depth dims] drops the first dimension of [dims] (it is not needed
   * as a width) and returns a function that takes the list of index
   * expressions. Indices are merged pairwise from the left as
   * (outer * width) + inner, consuming one width per merge, until a single
   * expression remains, e.g. ((i * m) + j) * l + k.
   *
   * Dim nodes in [dims] are wrapped in a VarUse annotated with Type Int and
   * the current nesting depth; other nodes are used as they are.
   * An empty index list raises InvalidNode. *)
  let rec expand depth dims =
    let as_use node =
      match node with
      | Dim _ -> VarUse (node, None, [Type Int; Depth depth])
      | _ -> node
    in
    let rec collapse widths indices =
      match indices with
      | [] -> raise InvalidNode
      | [last] -> dim_reduce depth last
      | outer :: inner :: rest ->
        let width = List.hd widths in
        let scaled = Binop (Mul, dim_reduce depth outer, width, [Type Int]) in
        let merged = Binop (Add, scaled, inner, [Type Int]) in
        collapse (List.tl widths) (merged :: rest)
    in
    collapse (List.map as_use (List.tl dims))

  (* Transform multi-dimensional arrays into one-dimensional arrays in
   * row-major order. [depth] is the current function nesting depth, which is
   * recorded on the dimension uses generated by [expand]. *)
  and dim_reduce depth node =
    (* Dimension list of the array declared by [dec]; InvalidNode if the
     * declaration does not have an ArrayDims type *)
    let dims_of dec =
      match typeof dec with
      | ArrayDims (_, dims) -> dims
      | _ -> raise InvalidNode
    in
    match node with
    (* Simplify array types in declarations *)
    | GlobalDef (exported, ArrayDims (elem, _), id, None, ann) ->
      GlobalDef (exported, Array elem, id, None, ann)
    | GlobalDec (ArrayDims (elem, _), id, ann) ->
      GlobalDec (Array elem, id, ann)
    | Param (ArrayDims (elem, _), id, ann) ->
      Param (Array elem, id, ann)
    | VarDec (ArrayDims (elem, _), id, None, ann) ->
      VarDec (Array elem, id, None, ann)

    (* Allocate in row-major order, with the array size being the product of
     * all dimensions *)
    | Allocate (dec, sizes, ann) ->
      Allocate (dec, [multiply sizes], ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (exported, ret_type, id, params, body, ann) ->
      let inner = dim_reduce (depth + 1) in
      FunDef (exported, ret_type, id, List.map inner params, inner body, ann)

    (* Expand indices when dereferencing *)
    | VarUse (dec, Some indices, ann) ->
      let dims = dims_of dec in
      VarUse (dec, Some [expand depth dims indices], ann)

    (* Expand indices when assigning to an array index; the assigned value is
     * reduced before the indices are expanded *)
    | VarLet (dec, Some indices, rhs, ann) ->
      let dims = dims_of dec in
      let rhs = dim_reduce depth rhs in
      VarLet (dec, Some [expand depth dims indices], rhs, ann)

    | _ -> traverse_unit (dim_reduce depth) node
  (* Entry point of this phase. For an Ast input, array dimensions are first
   * made explicit with expand_dims, after which dim_reduce is applied to the
   * result starting at nesting depth 0. Any other input raises
   * InvalidInput. *)
  let phase input =
    match input with
    | Ast tree ->
      let with_dims = expand_dims tree in
      Ast (dim_reduce 0 with_dims)
    | _ ->
      raise InvalidInput