ソースを参照

Added loop unrolling phase

Taddeus Kroes 12 年 前
コミット
65ff039e7c
7 ファイル変更172 行追加8 行削除
  1. 1 1
      Makefile
  2. 2 0
      main.ml
  3. 5 3
      phases/constprop.ml
  4. 4 1
      phases/constprop.mli
  5. 8 3
      phases/desug.ml
  6. 94 0
      phases/unroll.ml
  7. 58 0
      phases/unroll.mli

+ 1 - 1
Makefile

@@ -1,7 +1,7 @@
 RESULT := civicc
 GLOBALS := types globals stringify util
 PHASES := load parse print desug context typecheck extern dimreduce boolop \
-	constprop index assemble peephole output
+	constprop unroll index assemble peephole output
 SOURCES := $(addsuffix .mli,$(GLOBALS)) $(addsuffix .ml,$(GLOBALS)) \
 	lexer.mll parser.mly main.mli \
 	$(patsubst %,phases/%.mli,$(PHASES)) $(patsubst %,phases/%.ml,$(PHASES)) \

+ 2 - 0
main.ml

@@ -29,6 +29,8 @@ let phases = [
    "Convert bool operations");
   ("constprop", Constprop.phase, when_optimize,
    "Constant propagation");
+  ("unroll", Unroll.phase, when_optimize,
+   "Loop unrolling");
   ("index", Index.phase, always,
    "Index analysis");
   ("assemble", Assemble.phase, always,

+ 5 - 3
phases/constprop.ml

@@ -181,9 +181,11 @@ let rec prune_vardecs consts = function
   | VarDec (_, name, _, _) when Hashtbl.mem consts name -> DummyNode
   | node -> transform_children (prune_vardecs consts) node
 
-let phase = function
-  | Ast node ->
+let propagate_consts node =
     let consts = Hashtbl.create 32 in
     let node = propagate consts node in
-    Ast (prune_vardecs consts node)
+    prune_vardecs consts node
+
+let phase = function
+  | Ast node -> Ast (propagate_consts node)
   | _ -> raise (InvalidInput "constant propagation")

+ 4 - 1
phases/constprop.mli

@@ -74,5 +74,8 @@ Constant propagation reduces this to:
 \} v}
     *)
 
-(** Main phase function, called by {!Main}. *)
+(** Constant propagation traversal. Exported for use in {!Unroll}. *)
+val propagate_consts : Types.node -> Types.node
+
+(** Main phase function, called by {!Main}. Calls {!propagate_consts}. *)
 val phase : Main.phase_func

+ 8 - 3
phases/desug.ml

@@ -140,7 +140,8 @@ let rec move_inits = function
     begin match extract_inits decls with
     | ([], _) -> Program (decls, ann)
     | (inits, decls) ->
-      let init_func = FunDef (true, Void, "__init", [], Block inits, []) in
+      let body = Block (VarDecs [] :: LocalFuns [] :: inits) in
+      let init_func = FunDef (true, Void, "__init", [], body, []) in
       Program (init_func :: decls, ann)
     end
 
@@ -172,12 +173,16 @@ let for_to_while node =
   in
   let rec traverse new_vars = function
     | FunDef (export, ret_type, name, params, body, ann) ->
+      let rec place_decs decs = function
+        | Block (VarDecs lst :: tl) -> Block (VarDecs (decs @ lst) :: tl)
+        | _ -> raise InvalidNode
+      in
       let new_vars = ref [] in
       let body = traverse new_vars body in
       let create_vardec name = VarDec (Int, name, None, []) in
       let new_vardecs = List.map create_vardec !new_vars in
-      let _body = new_vardecs @ (flatten_blocks (block_body body)) in
-      FunDef (export, ret_type, name, params, Block _body, ann)
+      let body = place_decs new_vardecs body in
+      FunDef (export, ret_type, name, params, body, ann)
 
     (* Transform for-loops to while-loops *)
     | For (counter, start, stop, step, body, ann) ->

+ 94 - 0
phases/unroll.ml

@@ -0,0 +1,94 @@
+open Types
+open Util
+
+(* Unrolling heuristic: the unrolled code holds one copy of [body] for every
+ * value in [i_values], so only unroll when that product is at most 20
+ * statements *)
+let may_be_unrolled i_values body =
+  let iterations = List.length i_values in
+  let statements = List.length body in
+  iterations * statements <= 20
+
+(* Check whether a variable name has the form "<prefix>$<digits>", e.g.
+ * "i$2$4". The regular expression is compiled once, when this definition is
+ * evaluated, instead of on every call. *)
+let is_generated =
+  let generated_re = Str.regexp "^.+\\$[0-9]+$" in
+  fun s -> Str.string_match generated_re s 0
+
+(* [range i j step] is the list [i; i + step; i + 2 * step; ...] of all values
+ * smaller than [j]. The result is empty when [i >= j].
+ *
+ * Raises [Invalid_argument] when [i < j] and [step <= 0], since such a range
+ * would never terminate. The list is built with an accumulator so that long
+ * ranges do not exhaust the stack. *)
+let range i j step =
+  if i >= j then []
+  else if step <= 0 then invalid_arg "range: step must be positive"
+  else begin
+    let rec collect acc k =
+      if k >= j then List.rev acc else collect (k :: acc) (k + step)
+    in
+    collect [] i
+  end
+
+(* Check whether a node is an assignment ([VarLet]) whose target declaration
+ * is named [name]; any other node yields false. Only the node itself is
+ * inspected, so the function does not need to be recursive. *)
+let assigns name = function
+  | VarLet (dec, _, _, _) -> nameof dec = name
+  | _ -> false
+
+(* Substitute [replacement] for every non-indexed use of the variable called
+ * [name] inside the given node; all other nodes are traversed through
+ * [transform_children] *)
+let rec replace_var name replacement node =
+  match node with
+  | VarUse (VarDec (_, used, _, _), None, _) when used = name -> replacement
+  | other -> transform_children (replace_var name replacement) other
+
+(*
+ * Split the statements of a while-loop body into the actual body and the
+ * increment of the loop counter [i].
+ *
+ * The last statement must have the form [i = i + <integer constant>]. If so,
+ * the result is [Some (step, stmts)], where [stmts] holds all statements
+ * preceding the increment in their original order. The result is [None] when
+ * the list is empty or does not end with such an increment of [i] itself; an
+ * increment of any other variable is not accepted as the loop step.
+ *
+ * [rest] accumulates the statements seen so far in reverse order; callers
+ * start with [].
+ *)
+let rec get_body_step i rest = function
+  | [] -> None
+
+  | [VarLet (
+      VarDec (Int, assigned, None, _), None,
+      Binop (
+        Add,
+        VarUse (VarDec (Int, added, None, _), None, _),
+        Const (IntVal step, _),
+        _
+      ),
+      _
+    )] when assigned = i && added = i -> Some (step, List.rev rest)
+
+  | hd :: tl -> get_body_step i (hd :: rest) tl
+
+(*
+ * Unroll eligible loops in a list of statements. [counters] collects the names
+ * of the loop counters that have been eliminated, so that their declarations
+ * can be removed afterwards.
+ *)
+let rec unroll_body counters = function
+  | [] -> []
+
+  (*
+   * Look for the following pattern:
+   * i = start;
+   * while (a < stop) {
+   *   <body>;
+   *   b = c + step;
+   * }
+   * where a = b = c = i, where start, stop and step are integer constants with
+   * step > 0, and where i is a generated variable
+   *)
+  | (VarLet (VarDec (Int, i, None, _), None, Const (IntVal start, _), _) as init) ::
+    (While (
+      Binop (
+        Lt,
+        VarUse (VarDec (Int, comp, None, _), None, _),
+        Const (IntVal stop, _),
+      _),
+      Block body,
+    _) as loop) :: tl
+    when is_generated i && comp = i ->
+      (* Fallback: keep the loop itself, but still traverse its contents *)
+      let keep_loop () =
+        init :: (unroll counters loop) :: (unroll_body counters tl)
+      in
+      begin
+        match get_body_step i [] body with
+        (* A non-positive step never reaches the upper bound, so such a loop
+         * has no finite list of counter values to unroll over *)
+        | Some (step, rest) when step > 0 ->
+          (* Unroll inner loops first, so that the heuristic below counts the
+           * statements of the already unrolled body *)
+          let rest = flatten_blocks (unroll_body counters rest) in
+          let i_values = range start stop step in
+
+          if may_be_unrolled i_values rest then begin
+            (* [replace] instead of [add]: a counter may be registered more
+             * than once when an enclosing loop is traversed again *)
+            Hashtbl.replace counters i true;
+            let dup_body value =
+              replace_var i (Const (IntVal value, [Type Int])) (Block rest)
+            in
+            Block (List.map dup_body i_values) :: (unroll_body counters tl)
+          end else
+            keep_loop ()
+
+        | Some _ | None -> keep_loop ()
+      end
+
+  | hd :: tl -> (unroll counters hd) :: (unroll_body counters tl)
+
+(* Traverse a node, unrolling loops in every block of statements inside it *)
+and unroll counters = function
+  | Block stats -> Block (unroll_body counters stats)
+  | node -> transform_children (unroll counters) node
+
+(* Replace the declaration of every variable registered in [counters] with a
+ * [DummyNode]; all other nodes are traversed through [transform_children] *)
+let rec prune_vardecs counters node =
+  match node with
+  | VarDec (_, declared, _, _) when Hashtbl.mem counters declared -> DummyNode
+  | other -> transform_children (prune_vardecs counters) other
+
+(* Phase entry point: unroll loops, remove the declarations of the eliminated
+ * loop counters, and run constant propagation on the result. Any input other
+ * than an [Ast] raises [InvalidInput]. *)
+let phase = function
+  | Ast ast ->
+    let counters = Hashtbl.create 10 in
+    (* The unroll traversal must run first, since it fills [counters] *)
+    let unrolled = unroll counters ast in
+    let pruned = prune_vardecs counters unrolled in
+    Ast (Constprop.propagate_consts pruned)
+  | _ -> raise (InvalidInput "loop unrolling")

+ 58 - 0
phases/unroll.mli

@@ -0,0 +1,58 @@
+(** Unroll for-loops with constant boundaries. *)
+
+(** When initializing arrays with scalar values, the desugaring phase creates
+    for-loops, which are later transformed into while-loops. This increases
+    the stack size, because additional variables are introduced for the loop
+    counters. The readability of the generated code is also affected, since
+    while-loops are very verbose. Constant propagation helps somewhat, but the
+    loops still exist.
+
+    This loop unrolling phase recognizes while-loops that are generated from
+    for-loops. If the lower bound, upper bound, and step value of the loop are
+    constant integers, the loop is replaced with one copy of its body for each
+    iteration, in which the loop counter is replaced with its value in that
+    iteration. Afterwards, the constant propagation traversal is executed again
+    to utilise the new possibilities for constant folding that arise from
+    replacing the loop counter variable with a constant value. An example
+    follows:
+{v void foo() \{
+    int[2, 3] arr = 1;
+\} v}
+
+    After desugaring and constant propagation, this has been transformed into:
+
+{v void foo() \{
+    int i$2$4;
+    int i$3$7;
+    int[] arr;
+    arr := <allocate>(6);
+    i$2$4 = 0;
+    while ((i$2$4 < 2)) \{
+        i$3$7 = 0;
+        while ((i$3$7 < 3)) \{
+            arr[((i$2$4 * 3) + i$3$7)] = 1;
+            i$3$7 = (i$3$7 + 1);
+        \}
+        i$2$4 = (i$2$4 + 1);
+    \}
+\} v}
+
+    Now, the loops are unrolled (first the inner loop, then the outer loop) and
+    constant folding is applied to [arr[((i$2$4 * 3) + i$3$7)]] in each
+    iteration:
+
+{v void foo() \{
+    int[] arr;
+    arr := <allocate>(6);
+    arr[0] = 1;
+    arr[1] = 1;
+    arr[2] = 1;
+    arr[3] = 1;
+    arr[4] = 1;
+    arr[5] = 1;
+\} v}
+
+    A simple heuristic is applied to decide whether a recognised for-loop will
+    be unrolled: the resulting number of statements must not be larger than 20.
+    *)
+
+(** Main phase function, called by {!Main}. *)
+val phase : Main.phase_func