dim_reduce.ml 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. open Types
  2. open Util
  (* Make array dimensions explicit in function signatures and calls.
   *
   * - An array parameter Param (ArrayDims (t, dims), ...) becomes one Int
   *   parameter per Dim node, followed by the array parameter typed Array t.
   * - A call argument that passes an array declared by a VarDec becomes one
   *   Int argument per dimension, followed by the array argument itself
   *   (still wrapped in Arg, like its sibling arguments).
   * - VarDec nodes without an initial value lose their dimension list
   *   (ArrayDims -> Array).
   *
   * The parameter/argument rules return Block nodes; the enclosing FunDef,
   * FunDec and FunUse splice them into a flat list with flatten_blocks.
   * Raises InvalidNode if an array parameter has a dimension that is not a
   * Dim node. *)
  let rec expand_dims = function
    (* Flatten Block nodes returned by the transformations below *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDef (export, ret_type, name, params, expand_dims body, ann)
    | FunDec (ret_type, name, params, ann) ->
      let params = flatten_blocks (List.map expand_dims params) in
      FunDec (ret_type, name, params, ann)
    | FunUse (dec, args, ann) ->
      FunUse (dec, flatten_blocks (List.map expand_dims args), ann)

    (* Add an additional Int parameter for each array dimension *)
    | Param (ArrayDims (ctype, dims), name, ann) ->
      let rec do_expand = function
        | [] -> [Param (Array ctype, name, ann)]
        | Dim (dim_name, dim_ann) :: tail ->
          Param (Int, dim_name, dim_ann) :: do_expand tail
        | _ -> raise InvalidNode
      in
      Block (do_expand dims)

    (* Add an additional function argument for each array dimension.
     * NOTE(review): only array uses whose declaration is a VarDec match here;
     * if an array parameter can be passed on as an argument with a Param
     * declaration, it falls through to transform_children and receives no
     * dimension arguments - confirm that earlier phases rule this out. *)
    | Arg (VarUse (VarDec (ArrayDims (ctype, dims), name, None, decann), None, ann)) ->
      let rec do_expand = function
        | [] ->
          (* Remove the (now obsolete) dimensions from the type, and keep
           * the Arg wrapper so the array argument matches its siblings *)
          let dec = VarDec (Array ctype, name, None, decann) in
          [Arg (VarUse (dec, None, ann))]
        | hd :: tl ->
          (* A VarDec node has been added for each dimension during
           * desugaring, so we can safely reconstruct it here (we need no
           * reference because the type is immutable) *)
          let dimdec = VarDec (Int, nameof hd, None, annof hd) in
          Arg (VarUse (dimdec, None, [])) :: do_expand tl
      in
      Block (do_expand dims)

    (* Simplify array types in declarations *)
    | VarDec (ArrayDims (ctype, _), name, None, ann) ->
      VarDec (Array ctype, name, None, ann)

    | node -> transform_children expand_dims node
  (* Product of a non-empty list of Int expressions, nested to the right:
   * [a; b; c] becomes Binop (Mul, a, Binop (Mul, b, c, ...), ...).
   * A single-element list yields that element unchanged.
   * Raises InvalidNode on an empty list. *)
  let multiply factors =
    match List.rev factors with
    | [] -> raise InvalidNode
    | last :: earlier ->
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, [Type Int]))
        last earlier
  (* Turn an array dimension into an expression that reads it.
   * A Dim node (the kind of dimension found in array parameter types) is
   * wrapped in a VarUse annotated as Int and with the given nesting depth;
   * any other dimension expression is returned unchanged (presumably already
   * a usable expression produced by desugaring - TODO confirm). *)
  let use_dim depth dim =
    match dim with
    | Dim _ -> VarUse (dim, None, [Type Int; Depth depth])
    | _ -> dim
  (* Build the flat, row-major offset expression for an indexed array access.
   *
   * dims are the array dimensions in REVERSE declaration order (callers pass
   * List.rev dims); values are the index expressions in source order and are
   * themselves dimension-reduced first, since they may contain array accesses.
   * For dimensions d0..dk and indices i0..ik the result is Horner's form
   *   ((i0 * d1 + i1) * d2 + i2) ... * dk + ik
   * which equals the row-major offset i0*d1*...*dk + ... + ik; the outermost
   * dimension d0 is never needed. Assumes one index per dimension.
   * Raises InvalidNode if values is empty or the dims run out before the
   * indices do. *)
  let rec expand depth dims values =
    (* dims_rev: [dk; ...; d0]; indices_rev: [ik; ...; i0] *)
    let rec horner dims_rev = function
      | [] -> raise InvalidNode
      | [first] -> first
      | index :: outer ->
        (match dims_rev with
         | dim :: outer_dims ->
           let prefix = horner outer_dims outer in
           let scaled = Binop (Mul, prefix, use_dim depth dim, [Type Int]) in
           Binop (Add, scaled, index, [Type Int])
         | [] -> raise InvalidNode)
    in
    horner dims (List.rev_map (dim_reduce depth) values)

  (* Replace multi-dimensional array allocations and accesses by
   * one-dimensional ones. depth is the function nesting depth of the node
   * (0 at top level, +1 inside each FunDef); it is used to annotate the
   * dimension reads produced by use_dim. *)
  and dim_reduce depth = function
    (* Allocate a flat array of (product of all dimensions) elements *)
    | Allocate (dec, dims, ann) ->
      Allocate (dec, [multiply (List.map (dim_reduce depth) dims)], ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let trav = dim_reduce (depth + 1) in
      FunDef (export, ret_type, name, List.map trav params, trav body, ann)

    (* Expand indices when dereferencing *)
    | VarUse (dec, Some values, ann) as node ->
      (match typeof dec with
       | ArrayDims (_, dims) ->
         VarUse (dec, Some [expand depth (List.rev dims) values], ann)
       | _ -> transform_children (dim_reduce depth) node)

    (* Expand indices when assigning to an array element; the assigned value
     * may itself contain array accesses, so it is reduced as well *)
    | VarLet (dec, Some values, value, ann) as node ->
      (match typeof dec with
       | ArrayDims (_, dims) ->
         let index = expand depth (List.rev dims) values in
         VarLet (dec, Some [index], dim_reduce depth value, ann)
       | _ -> transform_children (dim_reduce depth) node)

    | node -> transform_children (dim_reduce depth) node
  (* Phase entry point: make array dimensions explicit in signatures and
   * calls (expand_dims), then flatten multi-dimensional array allocations and
   * accesses starting at nesting depth 0 (dim_reduce).
   * Raises InvalidInput for any input that is not an Ast. *)
  let phase input =
    match input with
    | Ast node ->
      let expanded = expand_dims node in
      Ast (dim_reduce 0 expanded)
    | _ -> raise (InvalidInput "dimension reduction")