unroll.ml 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. open Types
  2. open Util
  (* Size check for loop unrolling: the unrolled code contains one copy of
   * [body] for every value in [i_values], and the product of the two list
   * lengths may be at most 25. *)
  let may_be_unrolled i_values body =
    let copies = List.length i_values in
    let stmts_per_copy = List.length body in
    copies * stmts_per_copy <= 25
  (* Compiled once at module initialisation instead of on every call:
   * at least one character, a literal '$', then one or more decimal digits. *)
  let generated_name_re = Str.regexp "^.+\\$[0-9]+$"

  (* True when the variable name [s] has the form <prefix>$<digits>, matched
   * from the start of the string. *)
  let is_generated s = Str.string_match generated_name_re s 0
  (* Ascending list [i; i + step; i + 2 * step; ...] of all values strictly
   * below [j].
   *
   * Returns [] when [i >= j], whatever the step. Raises Invalid_argument when
   * [i < j] and [step <= 0]: such a sequence never reaches [j], so building it
   * would recurse without end. *)
  let range i j step =
    if i < j && step <= 0 then
      invalid_arg "range: step must be positive";
    (* Accumulate in reverse so the recursion is a tail call. *)
    let rec collect acc k =
      if k >= j then List.rev acc else collect (k :: acc) (k + step)
    in
    collect [] i
  (* True when [node] is a VarLet whose declaration has the name [name]
   * (as reported by nameof); false for every other node. *)
  let assigns name node =
    match node with
    | VarLet (dec, _, _, _) -> nameof dec = name
    | _ -> false
  (* Substitute [replacement] for every use of the variable called [name].
   *
   * Only VarUse nodes whose second field is None are replaced; every other
   * node is handed to transform_children so the substitution is applied to
   * its children. *)
  let rec replace_var name replacement node =
    match node with
    | VarUse (VarDec (_, var, _, _), None, _) when var = name -> replacement
    | other -> transform_children (replace_var name replacement) other
  (* Split a loop body into its leading statements and the step of the
   * trailing counter increment.
   *
   * [i] is the name of the loop counter. [rest] accumulates, in reverse, the
   * statements already walked past; callers start it at [].
   *
   * Returns Some (step, statements) when the last statement of the body has
   * the shape
   *   i = i + <integer constant>;
   * where statements is everything before it, in original order. Returns None
   * when the body is empty or ends in anything else, including an increment
   * of a variable other than [i].
   *)
  let rec get_body_step i rest = function
    | [] -> None
    | [VarLet (
        VarDec (Int, assigned, None, _), None,
        Binop (
          Add,
          VarUse (VarDec (Int, added, None, _), None, _),
          Const (IntVal step, _),
          _
        ),
        _
      )] when assigned = i && added = i -> Some (step, List.rev rest)
    | hd :: tl -> get_body_step i (hd :: rest) tl
  (* Unroll small counting loops inside a list of statements.
   *
   * [counters] is a hash table that receives the name of every loop counter
   * whose loop was unrolled. Statements that do not start an unrollable loop
   * are processed recursively with [unroll] and otherwise kept in place.
   *)
  let rec unroll_body counters = function
    | [] -> []
    (*
     * Look for the following pattern:
     *   i = start;
     *   while (a < stop) {
     *     <body>;
     *     b = c + step;
     *   }
     * where a = b = c = i, start, stop and step are integer constants, and i
     * is a generated variable.
     *)
    | (VarLet (VarDec (Int, i, None, _), None, Const (IntVal start, _), _) as init) ::
      (While (
        Binop (
          Lt,
          VarUse (VarDec (Int, comp, None, _), None, _),
          Const (IntVal stop, _),
          _),
        Block body,
        _) as loop) :: tl
      when is_generated i && comp = i ->
      (* Fallback: keep the initialisation and the loop, still processing the
       * loop and the remaining statements recursively. *)
      let keep_loop () =
        init :: (unroll counters loop) :: (unroll_body counters tl)
      in
      begin
        match get_body_step i [] body with
        (* A non-positive step can only be unrolled when the loop never runs
         * (start >= stop); otherwise the counter never reaches stop and the
         * list of counter values cannot be built. *)
        | Some (step, rest) when step > 0 || start >= stop ->
          (* Unroll nested loops first so the size check sees the final body. *)
          let rest = flatten_blocks (unroll_body counters rest) in
          let i_values = range start stop step in
          if may_be_unrolled i_values rest then begin
            (* Record the counter; its initialisation is dropped below. *)
            Hashtbl.add counters i true;
            (* One copy of the body with i replaced by a constant value. *)
            let dup_body value =
              replace_var i (Const (IntVal value, [Type Int])) (Block rest)
            in
            Block (List.map dup_body i_values) :: (unroll_body counters tl)
          end else
            keep_loop ()
        | _ -> keep_loop ()
      end
    | hd :: tl -> (unroll counters hd) :: (unroll_body counters tl)

  (* Apply loop unrolling to a single node: the statements of a Block go
   * through [unroll_body]; any other node has its children processed via
   * transform_children. *)
  and unroll counters = function
    | Block stats -> Block (unroll_body counters stats)
    | node -> transform_children (unroll counters) node
  (* Replace every VarDec whose name is a key of the [counters] table with
   * DummyNode. All other nodes are passed to transform_children so the same
   * pruning is applied to their children. *)
  let rec prune_vardecs counters node =
    match node with
    | VarDec (_, name, _, _) when Hashtbl.mem counters name -> DummyNode
    | other -> transform_children (prune_vardecs counters) other
  (* Entry point of the loop unrolling phase.
   *
   * For an Ast input: unroll loops, prune the declarations of the counters
   * recorded during unrolling, then run Constprop.propagate_consts on the
   * result. Any other input raises InvalidInput. *)
  let phase input =
    match input with
    | Ast tree ->
      let counters = Hashtbl.create 10 in
      let unrolled = unroll counters tree in
      let pruned = prune_vardecs counters unrolled in
      Ast (Constprop.propagate_consts pruned)
    | _ -> raise (InvalidInput "loop unrolling")