open Types
open Util

(* Loop unrolling pass.

   Loops of the shape

     i = start;
     while (i < stop) {
       <body>;
       i = i + step;
     }

   where [start], [stop], [step] are integer constants and [i] is a
   compiler-generated variable are replaced by one copy of [<body>] per
   iteration, with every use of [i] replaced by the iteration's constant.
   Declarations of counters that were eliminated this way are pruned
   afterwards. *)

(* Maximum number of statements an unrolled loop may expand to. *)
let unroll_limit = 25

(* Only unroll if the resulting number of statements is at most
   [unroll_limit]. Kept for callers that already hold the list of counter
   values; [unroll_body] uses [iteration_count] instead so that it never has
   to materialize a large range just to reject it. *)
let may_be_unrolled i_values body =
  List.length i_values * List.length body <= unroll_limit

(* [range i j step] lists the Int32 values [i], [i + step], ... that are
   strictly below [j], using Int32 (wrapping) addition like the original
   loop. Tail-recursive.
   @raise Invalid_argument if [i < j] and [step <= 0], since the sequence
   would never reach [j]. *)
let range i j step =
  if Int32.compare i j < 0 && Int32.compare step 0l <= 0 then
    invalid_arg "range: step must be positive";
  let rec go acc k =
    if Int32.compare k j >= 0 then List.rev acc
    else go (k :: acc) (Int32.add k step)
  in
  go [] i

(* Number of iterations of [for (i = start; i < stop; i += step)] with
   32-bit counter arithmetic, computed without enumerating the values.
   Returns [None] when the count cannot be determined statically: a
   non-positive step on a non-empty range, or a counter that would wrap
   past [Int32.max_int] before reaching [stop]. *)
let iteration_count start stop step =
  if Int32.compare start stop >= 0 then Some 0
  else if Int32.compare step 0l <= 0 then None
  else
    (* Work in Int64 so neither the span nor the exit value overflows. *)
    let start = Int64.of_int32 start
    and stop = Int64.of_int32 stop
    and step = Int64.of_int32 step in
    let n =
      Int64.div (Int64.add (Int64.sub stop start) (Int64.pred step)) step
    in
    (* Value the counter holds when the loop exits; if it does not fit in
       32 bits the counter wraps around and the loop keeps going. *)
    let exit_value = Int64.add start (Int64.mul n step) in
    if Int64.compare exit_value (Int64.of_int32 Int32.max_int) > 0 then None
    else Some (Int64.to_int n)

(* [assigns name node] is true when [node] is a direct assignment to the
   variable [name] (nested statements are not inspected). *)
let assigns name = function
  | VarLet (dec, _, _, _) -> nameof dec = name
  | _ -> false

(* Replace every plain (unindexed) use of variable [name] with
   [replacement], recursing through the whole subtree. *)
let rec replace_var name replacement = function
  | VarUse (VarDec (_, var, _, _), None, _) when var = name -> replacement
  | node -> traverse_unit (replace_var name replacement) node

(* If the last statement of a loop body is [i = i + step] for the loop
   counter [i] and a constant [step], return [Some (step, statements)]
   where [statements] is the body without that increment (in original
   order). [rest] is the reversed prefix seen so far; callers pass [].
   Both sides of the increment must be [i] itself: an increment of any
   other variable does not describe this loop's step. *)
let rec get_body_step i rest = function
  | [] -> None
  | [ VarLet
        ( VarDec (Int, assigned, None, _),
          None,
          Binop
            ( Add,
              VarUse (VarDec (Int, added, None, _), None, _),
              Const (IntVal step, _),
              _ ),
          _ ) ]
    when assigned = i && added = i ->
      Some (step, List.rev rest)
  | hd :: tl -> get_body_step i (hd :: rest) tl

(* Unroll every eligible loop in a statement list. The name of each counter
   whose loop was unrolled is recorded in [counters] so its declaration can
   be pruned later. *)
let rec unroll_body counters = function
  | [] -> []
  (* Look for the pattern described at the top of this file:
     i = start; while (i < stop) { <body>; i = i + step; } *)
  | (VarLet (VarDec (Int, i, None, _), None, Const (IntVal start, _), _)
     as init)
    :: (While
          ( Binop
              ( Lt,
                VarUse (VarDec (Int, comp, None, _), None, _),
                Const (IntVal stop, _),
                _ ),
            Block body,
            _ ) as loop)
    :: tl
    when is_generated_id i && comp = i -> (
      match get_body_step i [] body with
      | None -> init :: unroll counters loop :: unroll_body counters tl
      | Some (step, rest) -> (
          (* First unroll inner loops so the size check sees the final
             body. If this loop is then kept, [unroll counters loop]
             re-processes the original body; the decisions are
             deterministic, so the result is the same. *)
          let rest = flatten_blocks (unroll_body counters rest) in
          (* Every copy contributes at least one (possibly empty) block, so
             an empty body still counts one statement per iteration. *)
          let copy_size = max 1 (List.length rest) in
          match iteration_count start stop step with
          | Some n when n * copy_size <= unroll_limit ->
              Hashtbl.replace counters i true;
              let dup_body value =
                replace_var i (Const (IntVal value, [ Type Int ])) (Block rest)
              in
              (* [n] is small and overflow-free here, so [range] is safe. *)
              Block (List.map dup_body (range start stop step))
              :: unroll_body counters tl
          | _ -> init :: unroll counters loop :: unroll_body counters tl))
  | hd :: tl -> unroll counters hd :: unroll_body counters tl

(* Apply [unroll_body] to the statements of every block in the tree. *)
and unroll counters = function
  | Block stats -> Block (unroll_body counters stats)
  | node -> traverse_unit (unroll counters) node

(* Replace the declarations of unrolled counters with [DummyNode]. *)
let rec prune_vardecs counters = function
  | VarDec (_, name, _, _) when Hashtbl.mem counters name -> DummyNode
  | node -> traverse_unit (prune_vardecs counters) node

(* Pass entry point: unroll loops in an [Ast], then prune the counters.
   @raise InvalidInput on anything other than [Ast]. *)
let phase = function
  | Ast node ->
      let counters = Hashtbl.create 10 in
      Ast (unroll counters node |> prune_vardecs counters)
  | _ -> raise InvalidInput