unroll.ml 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. open Types
  2. open Util
  (* Largest number of statements a loop may expand to when unrolled. *)
  let max_unrolled_statements = 25

  (* Only unroll if the resulting number of statements (number of iterations
   * times number of body statements) is at most [max_unrolled_statements].
   * [i_values] are the counter values, one per iteration; [body] is the list
   * of statements duplicated for each of them. *)
  let may_be_unrolled i_values body =
    List.length i_values * List.length body <= max_unrolled_statements
  (* [range i j step] is [i; i + step; i + 2*step; ...] up to, but excluding,
   * [j], in increasing order. It is [[]] whenever [i >= j].
   *
   * Built with an accumulator so long ranges do not exhaust the stack. When
   * [i < j], a non-positive [step] could never reach [j]; instead of
   * recursing forever this raises [Invalid_argument]. Values are assumed not
   * to overflow near [max_int]. *)
  let range i j step =
    if i < j && step <= 0 then
      invalid_arg "Unroll.range: step must be positive";
    let rec collect acc v =
      if v >= j then List.rev acc else collect (v :: acc) (v + step)
    in
    collect [] i
  (* [assigns name stat] is [true] iff [stat] is a [VarLet] whose declaration,
   * as reported by [nameof] (presumably the declared identifier; confirm in
   * Util/Types), equals [name]. Every other node yields [false].
   * Only the node itself is inspected, not nested statements, so the
   * function is not recursive. *)
  let assigns name = function
    | VarLet (dec, _, _, _) -> nameof dec = name
    | _ -> false
  (* Substitute [replacement] for every read of the variable called [name]
   * in [node]. A read is a [VarUse] of a declaration with that name whose
   * second component is [None] (presumably "no array index"; confirm
   * against Types). All other nodes are rebuilt through [traverse_unit], so
   * the substitution reaches every nested node that function visits. *)
  let rec replace_var name replacement node =
    match node with
    | VarUse (VarDec (_, var_name, _, _), None, _) when var_name = name ->
      replacement
    | _ ->
      let recurse = replace_var name replacement in
      traverse_unit recurse node
  (* Split a loop body into its statements and the trailing increment of the
   * loop counter [i].
   *
   * If the last statement of the body is [i = i + step], where both sides
   * refer to the counter [i] and [step] is an integer constant, the result is
   * [Some (step, stats)]. Here [stats] are the statements before that
   * increment, in their original order. Otherwise the result is [None].
   * Both the assigned and the read variable must be [i]: a body ending in
   * some other [j = j + c] must not be mistaken for the counter increment.
   *
   * [rest] accumulates the already-visited statements in reverse order;
   * callers pass [[]]. *)
  let rec get_body_step i rest = function
    | [] -> None
    | [VarLet (
        VarDec (Int, assigned, None, _), None,
        Binop (
          Add,
          VarUse (VarDec (Int, added, None, _), None, _),
          Const (IntVal step, _),
          _
        ),
        _
      )] when assigned = i && added = i -> Some (step, List.rev rest)
    | hd :: tl -> get_body_step i (hd :: rest) tl
  (* Unroll small counting loops in a list of statements.
   *
   * Looks for the following pattern:
   *   i = start;
   *   while (a < stop) {
   *     <body>;
   *     b = c + step;
   *   }
   * Here a = b = c = i, start, stop and step are integer constants, and i is a
   * generated variable. When the loop is known to terminate and
   * [may_be_unrolled] accepts the expansion, the initialisation and the loop
   * are replaced by a single block. That block holds one copy of <body> per
   * iteration, with every read of i replaced by that iteration's value. i is
   * then recorded in [counters] so its declaration can be pruned afterwards.
   * Loops nested inside <body> are unrolled first. Every other statement is
   * handed to [unroll]. *)
  let rec unroll_body counters = function
    | [] -> []
    | (VarLet (VarDec (Int, i, None, _), None, Const (IntVal start, _), _) as init)
      :: (While (
            Binop (
              Lt,
              VarUse (VarDec (Int, comp, None, _), None, _),
              Const (IntVal stop, _),
              _),
            Block body,
            _) as loop)
      :: tl
      when is_generated_id i && comp = i ->
      (* Not unrollable: keep the initialisation and the loop, but still
       * unroll inside the loop and in the remaining statements. *)
      let keep () = init :: unroll counters loop :: unroll_body counters tl in
      begin match get_body_step i [] body with
      (* With a zero or negative step the counter never moves towards [stop],
       * so a loop whose body runs at all does not terminate. It must be left
       * alone, and enumerating its iterations with [range] would not
       * terminate either. If [start >= stop] the body never runs and the
       * loop can still be dropped. *)
      | Some (step, rest) when step > 0 || start >= stop ->
        let rest = flatten_blocks (unroll_body counters rest) in
        let i_values = range start stop step in
        if may_be_unrolled i_values rest then begin
          Hashtbl.add counters i true;
          let dup_body value =
            replace_var i (Const (IntVal value, [Type Int])) (Block rest)
          in
          Block (List.map dup_body i_values) :: unroll_body counters tl
        end else
          keep ()
      | _ -> keep ()
      end
    | hd :: tl -> unroll counters hd :: unroll_body counters tl

  (* Unroll eligible loops anywhere inside [node]. Statement lists of blocks
   * go through [unroll_body]; every other node is traversed with
   * [traverse_unit]. *)
  and unroll counters = function
    | Block stats -> Block (unroll_body counters stats)
    | node -> traverse_unit (unroll counters) node
  (* Replace every [VarDec] whose name is a key of [counters] by [DummyNode].
   * In this file, names are added to [counters] only for loop counters whose
   * loop was unrolled away. Every other node is rebuilt through
   * [traverse_unit], which applies the same pruning to its children. *)
  let rec prune_vardecs counters node =
    match node with
    | VarDec (_, var_name, _, _) when Hashtbl.mem counters var_name -> DummyNode
    | _ -> traverse_unit (prune_vardecs counters) node
  (* Loop-unrolling compiler phase. It unrolls eligible loops in the AST,
   * removes the declarations of the counters that were unrolled away, and
   * runs constant propagation over the result. Any input that is not an
   * [Ast] raises [InvalidInput]. *)
  let phase input =
    match input with
    | Ast root ->
      let unrolled_counters = Hashtbl.create 10 in
      let pruned =
        prune_vardecs unrolled_counters (unroll unrolled_counters root)
      in
      Ast (Constprop.propagate_consts pruned)
    | _ -> raise (InvalidInput "loop unrolling")
  75. let node = prune_vardecs counters node in
  76. Ast (Constprop.propagate_consts node)
  77. | _ -> raise (InvalidInput "loop unrolling")