desug.ml 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. open Printf
  2. open Types
  3. open Util
  (* Generate new variables for array dimensions in function bodies, to avoid
   * re-evaluation after array dimension reduction. For example:
   *
   * int dims = 0;
   *
   * int dim() {
   *     dims = dims + 1; // Side effect => dim() should be called only once
   *     return 10;
   * }
   *
   * void foo() {
   *     int[10, dim()] arr;
   *     arr[0, 1] = 1;
   * }
   *
   * After dimension reduction, this would become:
   * void foo() {
   *     int[] arr;
   *     arr = allocate(10, dim());
   *     arr[1 * dim() + 0] = 1;
   * }
   *
   * This behaviour is of course incorrect, because dim() is evaluated twice.
   * To prevent that, the snippet above is transformed into the following code
   * (note the $$, which will help later during constant propagation):
   * void foo() {
   *     int[arr$dim$$1, arr$dim$$2] arr;
   *     arr$dim$$1 = 10;
   *     arr$dim$$2 = dim();
   *     arr[0, 1] = 1;
   * }
   *
   * ... which later becomes:
   * void foo() {
   *     int[arr$dim$$1, arr$dim$$2] arr;
   *     arr$dim$$1 = 10;
   *     arr$dim$$2 = dim();
   *     arr = __allocate(arr$dim$$1 * arr$dim$$2);
   *     arr[1 * arr$dim$$2 + 0] = 1;
   * }
   *)
  (* Replace array-dimension expressions with dedicated scalar variables named
   * <array>$dim$$<k> (k counted from 1), placed in a Block before the array
   * declaration itself:
   * - VarDec with ArrayDims: one Int VarDec per dimension, initialised with
   *   the original dimension expression.
   * - GlobalDef with ArrayDims and no initialiser: one Int GlobalDef per
   *   dimension, carrying the same export flag as the array.
   * - GlobalDec with ArrayDims: the dimensions must already be Dim nodes; one
   *   Int GlobalDec is emitted per dimension name (InvalidNode otherwise).
   * In the first two cases the array's dimensions become Dim references to
   * the new variables. Everything else is traversed recursively. *)
  let rec array_dims node =
    (* Returns (declarations built by [make_dec], Dim references), one per
       entry of [values]. *)
    let make_dims basename values make_dec =
      let dim_name k _ = basename ^ "$dim$$" ^ string_of_int (k + 1) in
      let dim_names = mapi dim_name values in
      let dim_decs = List.map2 make_dec values dim_names in
      let dim_refs =
        List.map2 (fun value dim -> Dim (dim, annof value)) values dim_names
      in
      (dim_decs, dim_refs)
    in
    match node with
    | VarDec (ArrayDims (ctype, values), name, init, ann) ->
        let local_dec value dim = VarDec (Int, dim, Some value, []) in
        let (dim_decs, dim_refs) = make_dims name values local_dec in
        Block (dim_decs @ [VarDec (ArrayDims (ctype, dim_refs), name, init, ann)])
    | GlobalDef (export, ArrayDims (ctype, values), name, None, ann) ->
        let global_def value dim = GlobalDef (export, Int, dim, Some value, []) in
        let (dim_decs, dim_refs) = make_dims name values global_def in
        Block (dim_decs
               @ [GlobalDef (export, ArrayDims (ctype, dim_refs), name, None, ann)])
    | GlobalDec (ArrayDims (ctype, dims), name, ann) ->
        let dim_dec = function
          | Dim (dim, dim_ann) -> GlobalDec (Int, dim, dim_ann)
          | _ -> raise InvalidNode
        in
        Block (List.map dim_dec dims @ [GlobalDec (ArrayDims (ctype, dims), name, ann)])
    | other -> transform_children array_dims other
  (* Split every initialised declaration (local VarDec or GlobalDef) into a
   * Block holding the same declaration without initialiser, followed by an
   * Assign of the initial value to the declared name.
   * For local array declarations, a Const initialiser is first wrapped as
   * ArrayInit (ArrayScalar c, dims) and an ArrayConst as ArrayInit (a, dims),
   * so that array_init can later expand them with the dimensions at hand.
   * All other nodes are traversed recursively. *)
  let rec split_inits node =
    let dec_then_assign dec name init ann =
      Block [dec; Assign (name, None, init, ann)]
    in
    match node with
    | VarDec ((ArrayDims (_, dims) as ctype), name,
              Some ((Const _ | ArrayConst _) as value), ann) ->
        let payload =
          match value with
          | Const _ -> ArrayScalar value
          | _ -> value
        in
        dec_then_assign (VarDec (ctype, name, None, ann)) name
          (ArrayInit (payload, dims)) ann
    | VarDec (ctype, name, Some init, ann) ->
        dec_then_assign (VarDec (ctype, name, None, ann)) name init ann
    | GlobalDef (export, ctype, name, Some init, ann) ->
        dec_then_assign (GlobalDef (export, ctype, name, None, ann)) name init ann
    | other -> transform_children split_inits other
  (* Follow every array declaration (local VarDec or GlobalDef whose type is
   * ArrayDims) with an Allocate node for it, both grouped in a Block. Each
   * dimension must already be a Dim node and is passed to Allocate as a Var
   * reference to that name; any other dimension raises InvalidNode. *)
  let rec add_allocs node =
    let dim_ref = function
      | Dim (name, _) -> Var (name, None, [])
      | _ -> raise InvalidNode
    in
    match node with
    | VarDec (ArrayDims (_, dims), _, _, ann)
    | GlobalDef (_, ArrayDims (_, dims), _, _, ann) ->
        Block [node; Allocate (node, List.map dim_ref dims, ann)]
    | other -> transform_children add_allocs other
  (* Split [lst] into (initialisations, everything else). Assign and Allocate
   * nodes count as initialisations. Both result lists keep the relative order
   * the nodes had in [lst]. *)
  let extract_inits lst =
    let is_init = function
      | Assign _ | Allocate _ -> true
      | _ -> false
    in
    List.partition is_init lst
  (* Hoist initialisation statements (Assign / Allocate, see extract_inits) out
   * of declaration lists.
   * - Program: top-level initialisations are moved, in order, into a new
   *   exported, parameterless Void function "__init" that is placed first in
   *   the program. No "__init" is created when there are none.
   * - FunDef with a Block body: the body is expected to be zero or more
   *   VarDecs nodes followed by a LocalFuns node (then the statements).
   *   Initialisations are removed from every VarDecs and inserted, in order,
   *   directly after LocalFuns. Function definitions nested in LocalFuns are
   *   processed recursively as well.
   * Raises InvalidNode when a function body does not have that shape.
   * All other nodes are traversed recursively. *)
  let rec move_inits = function
    | Program (decls, ann) ->
        let decls = List.map move_inits decls in
        (match extract_inits decls with
         | ([], _) -> Program (decls, ann)
         | (inits, rest) ->
             let init_func = FunDef (true, Void, "__init", [], Block inits, []) in
             Program (init_func :: rest, ann))
    | FunDef (export, ret_type, name, params, Block body, ann) ->
        (* [pending] accumulates the initialisations of every VarDecs seen so
           far, so none are lost when the body holds more than one VarDecs. *)
        let rec place_inits pending = function
          | VarDecs lst :: tl ->
              let (inits, decs) = extract_inits lst in
              VarDecs decs :: place_inits (pending @ inits) tl
          | (LocalFuns _ as funs) :: tl ->
              (* Recurse into the nested function definitions; without this,
                 their local initialisations would stay inside VarDecs. *)
              move_inits funs :: pending @ tl
          | _ -> raise InvalidNode
        in
        FunDef (export, ret_type, name, params, Block (place_inits [] body), ann)
    | node -> transform_children move_inits node
  (* Lower every For loop into explicit assignments plus a While loop:
   *
   *   for (int i = start, stop, step) body
   *
   * becomes
   *
   *   i' = start; stop' = stop; step' = step;
   *   while (step' > 0 ? i' < stop' : i' > stop') {
   *       body[i := i'];
   *       i' = i' + step';
   *   }
   *
   * where i', stop' and step' are names produced by fresh_var / fresh_const.
   * The names created inside a FunDef are declared as Int VarDecs prepended
   * to that function's flattened body. A nested loop reusing the counter name
   * is renamed together with the outer loop and then renamed again when it is
   * lowered itself, so each loop ends up with its own counter. *)
  let for_to_while node =
    (* Substitute [fresh] for scalar references to [old], including the
       counter of nested For loops that reuse the same name. *)
    let rec rename old fresh node =
      let go = rename old fresh in
      match node with
      | Var (name, None, ann) when name = old -> Var (fresh, None, ann)
      | For (counter, start, stop, step, body, ann) when counter = old ->
          For (fresh, go start, go stop, go step, go body, ann)
      | other -> transform_children go other
    in
    let int_const n = Const (IntVal n, []) in
    (* [decls] collects the names that must be declared in the enclosing
       function. *)
    let rec lower decls = function
      | FunDef (export, ret_type, name, params, body, ann) ->
          let fun_decls = ref [] in
          let lowered = lower fun_decls body in
          let declare v = VarDec (Int, v, None, []) in
          let prologue = List.map declare !fun_decls in
          FunDef (export, ret_type, name, params,
                  Block (prologue @ flatten_blocks (block_body lowered)), ann)
      | For (counter, start, stop, step, body, ann) ->
          let idx = fresh_var counter in
          let upper = fresh_const "stop" in
          let stride = fresh_const "step" in
          decls := !decls @ [idx; upper; stride];
          let idx_ref = Var (idx, None, []) in
          let upper_ref = Var (upper, None, annof stop) in
          let stride_ref = Var (stride, None, annof step) in
          let cond =
            Cond (Binop (Gt, stride_ref, int_const 0, []),
                  Binop (Lt, idx_ref, upper_ref, []),
                  Binop (Gt, idx_ref, upper_ref, []),
                  [])
          in
          let advance =
            Assign (idx, None, Binop (Add, idx_ref, stride_ref, []), [])
          in
          let loop_body = Block (block_body (rename counter idx body) @ [advance]) in
          let loop = lower decls (While (cond, loop_body, ann)) in
          Block [
            Assign (idx, None, start, annof start);
            Assign (upper, None, stop, annof stop);
            Assign (stride, None, step, annof step);
            loop;
          ]
      (* DISABLED: while-loops are supported directly by the assembly phase.
         The former lowering into an if-guarded do-while was:
         | While (cond, body, ann) ->
             let cond = lower decls cond in
             let body = lower decls body in
             Block [If (cond, Block [DoWhile (cond, body, ann)], ann)]
      *)
      | other -> transform_children (lower decls) other
    in
    lower (ref []) node
  (* Expand the array initialisations produced by split_inits.
   * - Assign (a, ArrayInit (ArrayScalar v, dims)): one For loop per dimension
   *   (counter from fresh_var "i", from 0 to the dimension with step 1), with
   *   the single indexed assignment a[i1, ..., in] = v innermost.
   * - Assign (a, ArrayInit (ArrayConst ..., dims)): a Block of indexed
   *   assignments a[k1, ..., kn] = v, one per leaf value, in source order
   *   (depth first, left to right). Raises NodeError if a leaf is not nested
   *   exactly List.length dims levels deep. The number of elements per level
   *   is not checked against dims.
   *   TODO: only allow when array dimensions are constant?
   * All other nodes are traversed recursively. *)
  let rec array_init = function
    | Assign (name, None, ArrayInit (ArrayScalar value, dims), ann) ->
        let rec nest loop_vars = function
          | [] -> Assign (name, Some loop_vars, value, ann)
          | dim :: inner ->
              let counter = fresh_var "i" in
              let index = Var (counter, None, []) in
              let body = Block [nest (loop_vars @ [index]) inner] in
              For (counter, Const (IntVal 0, []), dim, Const (IntVal 1, []), body, [])
        in
        nest [] dims
    | Assign (name, None, ArrayInit (ArrayConst _ as value, dims), ann) ->
        let ndims = List.length dims in
        let index_const k = Const (IntVal k, []) in
        (* Prepend the assignments for [node], found at nesting [depth] with
           the reversed index path [path], onto [acc] (newest first). *)
        let rec collect depth path acc node =
          match node with
          | ArrayConst (values, _) ->
              let visit (k, acc) v = (k + 1, collect (depth + 1) (k :: path) acc v) in
              snd (List.fold_left visit (0, acc) values)
          | leaf when depth = ndims ->
              Assign (name, Some (List.rev_map index_const path), leaf, ann) :: acc
          | bad ->
              let msg = sprintf
                "dimension mismatch: expected %d nesting levels, got %d"
                ndims depth
              in
              raise (NodeError (bad, msg))
        in
        Block (List.rev (collect 0 [] [] value))
    | other -> transform_children array_init other
  (* Desugaring phase entry point. For an Ast input, apply array_dims,
   * split_inits, add_allocs, move_inits, array_init and for_to_while, in that
   * order. Any other input raises InvalidInput "desugar". *)
  let phase = function
    | Ast tree ->
        let tree = array_dims tree in
        let tree = split_inits tree in
        let tree = add_allocs tree in
        let tree = move_inits tree in
        let tree = array_init tree in
        Ast (for_to_while tree)
    | _ -> raise (InvalidInput "desugar")