(* dimreduce.ml
 *
 * Passes array dimensions explicitly to functions and transforms
 * multi-dimensional arrays into one-dimensional arrays in row-major order. *)
  1. open Types
  2. open Util
  (* Strip the dimension list from an array declaration: a declaration typed
   * [ArrayDims (ctype, dims)] is rebuilt with type [Array ctype], leaving all
   * other fields untouched. Accepted nodes are Param, GlobalDec, and
   * GlobalDef/VarDec without an initialiser ([None]); any other node raises
   * InvalidNode. *)
  let flatten_type dec =
    match dec with
    | Param (ArrayDims (elem, _), id, ann) ->
      Param (Array elem, id, ann)
    | VarDec (ArrayDims (elem, _), id, None, ann) ->
      VarDec (Array elem, id, None, ann)
    | GlobalDec (ArrayDims (elem, _), id, ann) ->
      GlobalDec (Array elem, id, ann)
    | GlobalDef (exported, ArrayDims (elem, _), id, None, ann) ->
      GlobalDef (exported, Array elem, id, None, ann)
    | _ ->
      raise InvalidNode
  (* Pass array dimensions explicitly to functions.
   *
   * An array parameter typed [ArrayDims (ctype, dims)] is replaced by one
   * [Int] parameter per dimension followed by the array itself, now typed
   * [Array ctype]. An array argument at a call site is expanded the same way:
   * one argument per dimension followed by the array. Because one node
   * expands into several, the replacement is returned wrapped in a Block and
   * spliced into the enclosing parameter/argument list by flatten_blocks.
   *
   * Raises InvalidNode when a dimension list contains anything but Dim nodes
   * or when the type of an array argument's declaration is not ArrayDims. *)
  let rec expand_dims = function
    (* Flatten Block nodes returned by transformations below *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDef (export, ret_type, name, params, expand_dims body, ann)

    | FunDec (ret_type, name, params, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDec (ret_type, name, params, ann)

    | FunUse (dec, params, ann) ->
      FunUse (dec, flatten_blocks (List.map expand_dims params), ann)

    (* Add additional parameters for array dimensions *)
    | Param (ArrayDims (ctype, dims), name, ann) ->
      let rec do_expand = function
        (* All dimensions emitted: finish with the array parameter itself *)
        | [] -> [Param (Array ctype, name, ann)]
        | Dim (dim_name, dim_ann) :: tail ->
          Param (Int, dim_name, dim_ann) :: (do_expand tail)
        | _ -> raise InvalidNode
      in
      Block (do_expand dims)

    (* Add additional function arguments for array dimensions *)
    | Arg (VarUse (dec, None, ann)) when is_array dec ->
      let make_dimdec = function
        | Dim (dim_name, dim_ann) -> Param (Int, dim_name, dim_ann)
        | _ -> raise InvalidNode
      in
      let rec do_expand = function
        | [] ->
          (* Remove the (now obsolete) dimensions from the type. The array
           * stays wrapped in Arg, like the node it replaces and like the
           * dimension arguments emitted before it. *)
          [Arg (VarUse (flatten_type dec, None, ann))]
        | hd :: tl ->
          (* A declaration has been added for each dimension during earlier
           * phases, so we can safely reconstruct it here *)
          Arg (VarUse (make_dimdec hd, None, [])) :: (do_expand tl)
      in
      let dims = match typeof dec with
        | ArrayDims (_, dims) -> dims
        | _ -> raise InvalidNode
      in
      Block (do_expand dims)

    (* Everything else: recurse into the children *)
    | node -> traverse_unit expand_dims node
  (* Combine a non-empty list of expressions into one right-nested product
   * [a * (b * (c * ...))]; every Mul node is annotated with [Type Int]. A
   * single-element list yields that element unchanged, and the empty list
   * raises InvalidNode. *)
  let multiply factors =
    (* Start from the last factor and wrap it with the preceding ones, which
     * produces the right-nested shape without explicit recursion *)
    match List.rev factors with
    | [] -> raise InvalidNode
    | last :: preceding ->
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, [Type Int]))
        last
        preceding
  (* Like a plain product builder: combine a non-empty list of expressions
   * into a right-nested chain of Mul nodes, each carrying an empty annotation
   * list. A single-element list yields that element unchanged; the empty list
   * raises InvalidNode. *)
  let rec multiply_all operands =
    match operands with
    | [] -> raise InvalidNode
    | [single] -> single
    | first :: rest ->
      let rest_product = multiply_all rest in
      Binop (Mul, first, rest_product, [])
  (* Collapse the index list of a multi-dimensional array access into a single
   * index expression. [dims] is the dimension list taken from the array's
   * ArrayDims type; applying [expand depth dims] yields a function that takes
   * the list of index expressions.
   *
   * The first dimension is skipped; for dimensions [d0; d1; d2] and indices
   * [i; j; k] the result has the shape ((i * d1) + j) * d2 + k, where all
   * Mul/Add nodes are annotated [Type Int] and the operands are passed through
   * dim_reduce. An empty index list raises InvalidNode. List.tl / List.hd are
   * applied to the dimension list unguarded, so an empty [dims], or more
   * indices than dimensions, fails inside those calls. *)
  let rec expand depth dims =
    (* A Dim node becomes a use of that dimension, annotated as Int at the
     * current nesting depth; any other node is used as-is *)
    let width_of = function
      | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
      | other -> other
    in
    let rec collapse widths indices =
      match indices with
      | [] -> raise InvalidNode
      | [last] -> dim_reduce depth last
      | outer :: inner :: rest ->
        (* Scale the outer index by the width of the next dimension, add the
         * inner index, and continue with the combined expression in front *)
        let width = List.hd widths in
        let scaled = Binop (Mul, dim_reduce depth outer, width, [Type Int]) in
        collapse (List.tl widths) (Binop (Add, scaled, inner, [Type Int]) :: rest)
    in
    collapse (List.map width_of (List.tl dims))

  (* Transform multi-dimensional arrays into one-dimensional arrays in row-major
   * order. [depth] is the current function nesting depth; it is recorded in
   * the Depth annotation of the dimension uses created by expand. *)
  and dim_reduce depth node =
    (* Dimension list of the declaration's type; InvalidNode if the type is not
     * ArrayDims *)
    let array_dims dec =
      match typeof dec with
      | ArrayDims (_, dims) -> dims
      | _ -> raise InvalidNode
    in
    match node with
    (* An allocation gets a single size: the product of all its dimensions *)
    | Allocate (dec, dims, ann) ->
      Allocate (dec, [multiply dims], ann)

    (* Simplify array types in declarations *)
    | VarDec (ArrayDims (ctype, _), name, None, ann) ->
      VarDec (Array ctype, name, None, ann)
    | GlobalDec (ArrayDims (ctype, _), name, ann) ->
      GlobalDec (Array ctype, name, ann)
    | GlobalDef (export, ArrayDims (ctype, _), name, None, ann) ->
      GlobalDef (export, Array ctype, name, None, ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let inner = dim_reduce (depth + 1) in
      FunDef (export, ret_type, name, List.map inner params, inner body, ann)

    (* Expand indices when dereferencing *)
    | VarUse (dec, Some indices, ann) ->
      let dims = array_dims dec in
      VarUse (dec, Some [expand depth dims indices], ann)

    (* Expand indices when assigning to array index; the assigned value is
     * reduced before the indices are expanded *)
    | VarLet (dec, Some indices, value, ann) ->
      let dims = array_dims dec in
      let value = dim_reduce depth value in
      VarLet (dec, Some [expand depth dims indices], value, ann)

    (* Everything else: recurse into the children at the same depth *)
    | other -> traverse_unit (dim_reduce depth) other
  (* Entry point of this phase. For an [Ast] input, expand_dims is run over the
   * tree first and dim_reduce is then applied to its result, starting at
   * nesting depth 0. Any other input raises InvalidInput. *)
  let phase input =
    match input with
    | Ast tree ->
      let with_explicit_dims = expand_dims tree in
      Ast (dim_reduce 0 with_explicit_dims)
    | _ ->
      raise InvalidInput