dimreduce.ml 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. open Types
  2. open Util
  (* Pass array dimensions explicitly to functions.
   *
   * Function parameters and call arguments whose type is ArrayDims are
   * replaced by one int parameter/argument per dimension, followed by the
   * array itself retyped as a plain Array (dimensions dropped from the type).
   * Each such parameter/argument expands into a Block; the FunDef, FunDec and
   * FunUse cases below flatten those Blocks back into their parameter lists
   * with flatten_blocks. All other nodes are traversed recursively.
   *
   * Raises InvalidNode when an ArrayDims parameter lists a dimension that is
   * not a Dim node. *)
  let rec expand_dims = function
    (* Flatten Block nodes returned by transformations below *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDef (export, ret_type, name, params, expand_dims body, ann)

    | FunDec (ret_type, name, params, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDec (ret_type, name, params, ann)

    | FunUse (dec, params, ann) ->
      FunUse (dec, flatten_blocks (List.map expand_dims params), ann)

    (* Add additional parameters for array dimensions *)
    | Param (ArrayDims (ctype, dims), name, ann) ->
      let rec do_expand = function
        | [] -> [Param (Array ctype, name, ann)]
        | Dim (dim_name, dim_ann) :: tail ->
          Param (Int, dim_name, dim_ann) :: do_expand tail
        | _ -> raise InvalidNode
      in
      Block (do_expand dims)

    (* Add additional function arguments for array dimensions.
     * NOTE(review): only arguments whose declaration is a VarDec with an
     * ArrayDims type are matched here; confirm that arrays declared in any
     * other way cannot reach this point. *)
    | Arg (VarUse (VarDec (ArrayDims (ctype, dims), name, None, decann), None, ann)) ->
      let rec do_expand = function
        | [] ->
          (* Remove the (now obsolete) dimensions from the type. The array
           * argument itself stays wrapped in Arg, just like the dimension
           * arguments generated below. *)
          let dec = VarDec (Array ctype, name, None, decann) in
          [Arg (VarUse (dec, None, ann))]
        | hd :: tl ->
          (* A VarDec node has been added for each dimension during
           * desugaring, so we can safely reconstruct it here (we need no
           * reference because the type is immutable, yay!) *)
          let dimdec = VarDec (Int, nameof hd, None, annof hd) in
          Arg (VarUse (dimdec, None, [])) :: do_expand tl
      in
      Block (do_expand dims)

    | node -> transform_children expand_dims node
  (* Combine a non-empty list of expressions [e1; e2; ...; en] into the
   * right-nested product e1 * (e2 * (... * en)), annotating every generated
   * Mul node with [Type Int]. A single-element list yields that element
   * unchanged; an empty list raises InvalidNode. *)
  let multiply nodes =
    match List.rev nodes with
    | [] -> raise InvalidNode
    | last :: preceding ->
      (* Walk from the last factor back to the first, so each earlier factor
       * becomes the left operand of the product built so far. *)
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, [Type Int]))
        last preceding
  (* Combine a non-empty list of expressions [e1; e2; ...; en] into the
   * right-nested product e1 * (e2 * (... * en)). The generated Mul nodes
   * carry an empty annotation list. A single-element list yields that
   * element unchanged; an empty list raises InvalidNode. *)
  let multiply_all nodes =
    match List.rev nodes with
    | [] -> raise InvalidNode
    | last :: preceding ->
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, []))
        last preceding
  (* Build the row-major linear index for an access to an array whose type
   * lists the dimensions [dims] = [d0; d1; ...; dn], indexed by
   * [values] = [i0; i1; ...; ik]:
   *
   *   ((i0 * d1 + i1) * d2 + i2) ... * dk + ik
   *
   * The first dimension d0 never occurs in the result. Dim nodes used as
   * widths become int VarUses annotated with the current nesting [depth].
   * Every index expression and every width is passed through dim_reduce
   * exactly once. Raises InvalidNode when [dims] or [values] is empty, or
   * when there are more indices than dimensions. *)
  let rec expand depth dims values =
    let use_dim = function
      | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
      | node -> node
    in
    (* Widths of every dimension except the first *)
    let widths =
      match dims with
      | [] -> raise InvalidNode
      | _ :: tl -> List.map (fun dim -> dim_reduce depth (use_dim dim)) tl
    in
    let rec accumulate acc widths indices =
      match widths, indices with
      | _, [] -> acc
      | width :: widths, index :: indices ->
        let mul = Binop (Mul, acc, width, [Type Int]) in
        accumulate (Binop (Add, mul, index, [Type Int])) widths indices
      | [], _ :: _ -> raise InvalidNode
    in
    match List.map (dim_reduce depth) values with
    | [] -> raise InvalidNode
    | first :: rest -> accumulate first widths rest

  (* Transform multi-dimensional arrays into one-dimensional arrays in row-major
   * order. [depth] is the function nesting depth: it starts at the caller's
   * value, is incremented on entering each FunDef, and is attached as a Depth
   * annotation to the dimension uses generated by expand. *)
  and dim_reduce depth = function
    (* Allocate a single flat array whose size is the product of all
     * dimensions; the dimension expressions themselves are reduced first *)
    | Allocate (dec, dims, ann) ->
      Allocate (dec, [multiply (List.map (dim_reduce depth) dims)], ann)

    (* Simplify array types in declarations *)
    | VarDec (ArrayDims (ctype, _), name, None, ann) ->
      VarDec (Array ctype, name, None, ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let trav = dim_reduce (depth + 1) in
      FunDef (export, ret_type, name, List.map trav params, trav body, ann)

    (* Expand indices when dereferencing; when the array type carries no
     * dimensions, the index expressions are still reduced because they may
     * contain multi-dimensional accesses themselves *)
    | VarUse (dec, Some values, ann) ->
      let values =
        match typeof dec with
        | ArrayDims (_, dims) -> [expand depth dims values]
        | _ -> List.map (dim_reduce depth) values
      in
      VarUse (dec, Some values, ann)

    (* Expand indices when assigning to an array index; the assigned value is
     * reduced as well, since it may contain multi-dimensional accesses *)
    | VarLet (dec, Some values, value, ann) ->
      let values =
        match typeof dec with
        | ArrayDims (_, dims) -> [expand depth dims values]
        | _ -> List.map (dim_reduce depth) values
      in
      VarLet (dec, Some values, dim_reduce depth value, ann)

    | node -> transform_children (dim_reduce depth) node
  (* Entry point of the dimension-reduction phase. For an AST input, array
   * dimensions are first made explicit in function signatures and calls
   * (expand_dims); the result is then flattened to one-dimensional arrays
   * (dim_reduce, starting at nesting depth 0). Any other input is rejected
   * with InvalidInput. *)
  let phase input =
    match input with
    | Ast root ->
      let with_explicit_dims = expand_dims root in
      Ast (dim_reduce 0 with_explicit_dims)
    | _ -> raise (InvalidInput "dimension reduction")