Jelajahi Sumber

Added a rudimentary form of constant propagation for reducing for-loop complexity, improved flatten_blocks

Taddeus Kroes 12 tahun lalu
induk
melakukan
669f2c21f3
8 mengubah file dengan 201 tambahan dan 20 penghapusan
  1. 2 2
      Makefile
  2. 9 7
      main.ml
  3. 147 0
      phases/constant_propagation.ml
  4. 2 2
      phases/desug.ml
  5. 0 2
      test/array_init.cvc
  6. 9 0
      test/constant_propagation.cvc
  7. 28 6
      util.ml
  8. 4 1
      util.mli

+ 2 - 2
Makefile

@@ -1,6 +1,6 @@
 RESULT := civicc
-PHASES := load parse print desug context_analysis expand_dims typecheck \
-	dim_reduce bool_op extern_vars
+PHASES := load parse print desug constant_propagation context_analysis \
+	expand_dims typecheck dim_reduce bool_op extern_vars
 SOURCES := ast.ml stringify.mli stringify.ml util.mli util.ml lexer.mll \
 	parser.mly $(patsubst %,phases/%.ml,$(PHASES)) main.ml
 PRE_TARGETS := ast.cmi ast.o stringify.cmi stringify.o util.cmi util.o

+ 9 - 7
main.ml

@@ -14,21 +14,23 @@ let compile () =
         Load.phase;
         (*Print.phase;*)
         Parse.phase;
-        (*Print.phase*)
+        (*Print.phase;*)
         Desug.phase;
         Print.phase;
-        Context_analysis.phase;
+        Constant_propagation.phase;
         Print.phase;
+        Context_analysis.phase;
+        (*Print.phase;*)
         Typecheck.phase;
-        (*Print.phase*)
+        (*Print.phase;*)
         Expand_dims.phase;
-        (*Print.phase*)
+        (*Print.phase;*)
         Bool_op.phase;
-        (*Print.phase*)
+        (*Print.phase;*)
         Dim_reduce.phase;
-        Print.phase;
+        (*Print.phase;*)
         Extern_vars.phase;
-        Print.phase;
+        (*Print.phase;*)
         (*
         Assemble.phase;
         Print.phase;

+ 147 - 0
phases/constant_propagation.ml

@@ -0,0 +1,147 @@
+(**
+ * The compiler sometimes generates variables of the form foo$1, to make sure
+ * that expressions are only executed once. In many cases, this leads to
+ * over-complex constructions, for example when converting for-loops to
+ * while-loops. We use the knowledge of these variables being constant by
+ * propagating the constant values to their occurrences, and then apply
+ * arithmetic simplification to operators to reduce the size and complexity of
+ * the generated code. Note that this can only be applied to constants. For
+ * variables in general, some form of liveness analysis would be required (e.g.
+ * Static Single Assignment form). Expressions can only be propagated when they
+ * have no side effects, i.e. when they do not contain function calls.
+ *)
+open Ast
+open Util
+
+(* Tell whether a variable name carries the double-dollar marker followed by
+ * digits (e.g. "stop$$1"). The pattern is matched from the start of the name
+ * and is compiled once, when this definition is evaluated, rather than on
+ * every call. *)
+let is_const_name =
+    let const_name_re = Str.regexp "[^\\$]+\\$\\$[0-9]+" in
+    fun name -> Str.string_match const_name_re name 0
+
+(* Tell whether a node is a literal boolean, integer or float constant *)
+let is_const node =
+    match node with
+    | BoolConst _ | IntConst _ | FloatConst _ -> true
+    | _ -> false
+
+(* Fold a unary operation applied to a literal. Takes an (operator, operand,
+ * location) tuple; the folded constant carries the location of the operation.
+ * Any combination that cannot be folded is rebuilt as a Monop node. *)
+let eval_monop (op, opnd, loc) =
+    match op, opnd with
+    | Not, BoolConst  (b, _) -> BoolConst  (not b, loc)
+    | Neg, IntConst   (n, _) -> IntConst   (-n, loc)
+    | Neg, FloatConst (x, _) -> FloatConst (-.x, loc)
+    | _ -> Monop (op, opnd, loc)
+
+(* Fold a binary operation whose operands are literals of matching type.
+ * Takes an (operator, left, right, location) tuple and returns a constant
+ * carrying the location of the operation. Any combination that cannot be
+ * folded is rebuilt as a Binop node, so the function is total. *)
+let eval_binop = function
+    (* Arithmetic *)
+    | (Add, IntConst (left, _), IntConst (right, _), loc) ->
+        IntConst (left + right, loc)
+    | (Add, FloatConst (left, _), FloatConst (right, _), loc) ->
+        FloatConst (left +. right, loc)
+
+    | (Sub, IntConst (left, _), IntConst (right, _), loc) ->
+        IntConst (left - right, loc)
+    | (Sub, FloatConst (left, _), FloatConst (right, _), loc) ->
+        FloatConst (left -. right, loc)
+
+    | (Mul, IntConst (left, _), IntConst (right, _), loc) ->
+        IntConst (left * right, loc)
+    | (Mul, FloatConst (left, _), FloatConst (right, _), loc) ->
+        FloatConst (left *. right, loc)
+
+    (* Integer division and modulo by zero raise Division_by_zero in OCaml.
+     * Such operations are left in the tree (they fall through to the last
+     * case) instead of aborting the compiler while folding. *)
+    | (Div, IntConst (left, _), IntConst (right, _), loc) when right <> 0 ->
+        IntConst (left / right, loc)
+    | (Div, FloatConst (left, _), FloatConst (right, _), loc) ->
+        FloatConst (left /. right, loc)
+
+    | (Mod, IntConst (left, _), IntConst (right, _), loc) when right <> 0 ->
+        IntConst (left mod right, loc)
+
+    (* Relational *)
+    | (Eq, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left = right, loc)
+    | (Eq, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left = right, loc)
+
+    (* Structural inequality (<>) is required here: (!=) compares physical
+     * identity, which is not a value comparison for boxed floats *)
+    | (Ne, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left <> right, loc)
+    | (Ne, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left <> right, loc)
+
+    | (Gt, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left > right, loc)
+    | (Gt, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left > right, loc)
+
+    | (Lt, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left < right, loc)
+    | (Lt, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left < right, loc)
+
+    | (Ge, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left >= right, loc)
+    | (Ge, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left >= right, loc)
+
+    | (Le, IntConst (left, _), IntConst (right, _), loc) ->
+        BoolConst (left <= right, loc)
+    | (Le, FloatConst (left, _), FloatConst (right, _), loc) ->
+        BoolConst (left <= right, loc)
+
+    (* Logical *)
+    | (And, BoolConst (left, _), BoolConst (right, _), loc) ->
+        BoolConst (left && right, loc)
+    | (Or, BoolConst (left, _), BoolConst (right, _), loc) ->
+        BoolConst (left || right, loc)
+
+    (* Anything else is not foldable: rebuild the original operation *)
+    | (op, left, right, loc) -> Binop (op, left, right, loc)
+
+(* Walk the tree, recording literal values assigned to generated constants in
+ * the [consts] table, substituting those values for later occurrences of the
+ * constant, and folding operators whose operands have become literals. *)
+let rec propagate consts node =
+    let recurse = propagate consts in
+    match node with
+
+    (* An assignment to a generated constant is recorded in the table and
+     * removed from the tree when its value folds to a literal; otherwise the
+     * assignment is kept with its simplified value *)
+    | Assign (name, None, value, loc) when is_const_name name ->
+        let folded = recurse value in
+        if not (is_const folded) then
+            Assign (name, None, folded, loc)
+        else begin
+            Hashtbl.add consts name folded;
+            DummyNode
+        end
+
+    (* An occurrence of a recorded constant is replaced by its value *)
+    | Var (name, _) when Hashtbl.mem consts name ->
+        Hashtbl.find consts name
+
+    (* Operators are simplified after their operands. The evaluators rebuild
+     * the operator node themselves when the operands cannot be folded. *)
+    | Monop (op, opnd, loc) ->
+        eval_monop (op, recurse opnd, loc)
+
+    | Binop (op, left, right, loc) ->
+        let lhs = recurse left in
+        let rhs = recurse right in
+        eval_binop (op, lhs, rhs, loc)
+
+    (* A conditional expression with a literal condition collapses to the
+     * selected branch *)
+    | Cond (cond, texp, fexp, loc) ->
+        let test = recurse cond in
+        let on_true = recurse texp in
+        let on_false = recurse fexp in
+        begin match test with
+        | BoolConst (true, _)  -> on_true
+        | BoolConst (false, _) -> on_false
+        | _ -> Cond (test, on_true, on_false, loc)
+        end
+
+    | _ -> transform_children recurse node
+
+(* Replace the declaration of every variable recorded in [consts] with an
+ * empty block; all other nodes are traversed recursively *)
+let rec prune_vardecs consts node =
+    match node with
+    | VarDec (_, name, _, _) when Hashtbl.mem consts name -> Block []
+    | _ -> transform_children (prune_vardecs consts) node
+
+(* Phase entry point: propagate constants through the tree, then remove the
+ * declarations of the constants that were propagated. Raises InvalidInput
+ * when the input is not an Ast. *)
+let phase input =
+    prerr_endline "- Constant propagation";
+    match input with
+    | Ast tree ->
+        let table = Hashtbl.create 32 in
+        let propagated = propagate table tree in
+        Ast (prune_vardecs table propagated)
+    | _ -> raise (InvalidInput "constant propagation")

+ 2 - 2
phases/desug.ml

@@ -119,8 +119,8 @@ let for_to_while node =
         (* Transform for-loops to while-loops *)
         | For (counter, start, stop, step, body, loc) ->
             let _i = fresh_var counter in
-            let _stop = fresh_var "stop" in
-            let _step = fresh_var "step" in
+            let _stop = fresh_const "stop" in
+            let _step = fresh_const "step" in
             new_vars := !new_vars @ [_i; _stop; _step];
 
             let vi = Var (_i, noloc) in

+ 0 - 2
test/array_init.cvc

@@ -1,8 +1,6 @@
 extern void printInt(int i);
 extern void printNewlines(int i);
 
-int[i, j] glob;
-
 export int main() {
     int[2,3] a = [[1,2,3], [4,5,6]];
     int[2,3] b = 7;

+ 9 - 0
test/constant_propagation.cvc

@@ -0,0 +1,9 @@
+extern void printInt(int val);
+
+export int main(int a, int b) {
+    for (int i = a, b) {
+        printInt(i);
+    }
+
+    return 0;
+}

+ 28 - 6
util.ml

@@ -24,6 +24,10 @@ let fresh_var prefix =
     var_counter := !var_counter + 1;
     prefix ^ "$" ^ string_of_int !var_counter
 
+(* Generate a fresh name for a constant. An extra "$" is appended to the
+ * prefix, so the generated name contains "$$" and can be recognised during
+ * constant propagation *)
+let fresh_const prefix =
+    let marked_prefix = prefix ^ "$" in
+    fresh_var marked_prefix
+
 let loc_from_lexpos pstart pend =
     let (fname, ystart, yend, xstart, xend) = (
         pstart.pos_fname,
@@ -37,13 +41,36 @@ let loc_from_lexpos pstart pend =
     else
         (fname, ystart, yend, xstart, xend)
 
+(* Flatten a node list: the contents of nested Block nodes are spliced into
+ * the surrounding list, DummyNode entries are dropped, and the Block bodies of
+ * function definitions, conditionals and loops are flattened recursively *)
+let rec flatten_blocks nodes =
+    (* Rebuild a single node with its directly nested Block bodies flattened;
+     * nodes of any other shape are returned untouched *)
+    let flatten_inside = function
+        | FunDef (export, ret_type, name, params, Block body, loc) ->
+            let params = flatten_blocks params in
+            let body = flatten_blocks body in
+            FunDef (export, ret_type, name, params, Block body, loc)
+        | If (cond, Block body, loc) ->
+            If (cond, Block (flatten_blocks body), loc)
+        | IfElse (cond, Block tbody, Block fbody, loc) ->
+            let tbody = flatten_blocks tbody in
+            let fbody = flatten_blocks fbody in
+            IfElse (cond, Block tbody, Block fbody, loc)
+        | While (cond, Block body, loc) ->
+            While (cond, Block (flatten_blocks body), loc)
+        | DoWhile (cond, Block body, loc) ->
+            DoWhile (cond, Block (flatten_blocks body), loc)
+        | For (counter, start, stop, step, Block body, loc) ->
+            For (counter, start, stop, step, Block (flatten_blocks body), loc)
+        | other -> other
+    in
+    (* Every list element expands to zero, one or several flattened nodes *)
+    let expand = function
+        | Block inner -> flatten_blocks inner
+        | DummyNode   -> []
+        | single      -> [flatten_inside single]
+    in
+    List.concat (List.map expand nodes)
+
 (* Default tree transformation
  * (node -> node) -> node -> node *)
 let transform_children trav node =
     let trav_all nodes = List.map trav nodes in
     match node with
     | Program (decls, loc) ->
-        Program (trav_all decls, loc)
+        Program (flatten_blocks (trav_all decls), loc)
     | FunDec (ret_type, name, params, loc) ->
         FunDec (ret_type, name, trav_all params, loc)
     | FunDef (export, ret_type, name, params, body, loc) ->
@@ -194,11 +221,6 @@ let prerr_loc_msg loc msg verbose =
     if verbose >= 2 then prerr_loc loc;
     ()
 
-let rec flatten_blocks = function
-    | [] -> []
-    | Block nodes :: t -> (flatten_blocks nodes) @ (flatten_blocks t)
-    | h :: t -> h :: (flatten_blocks t)
-
 let ctypeof = function
     | VarDec (ctype, _, _, _)
     | Param (ctype, _, _)

+ 4 - 1
util.mli

@@ -6,9 +6,12 @@ val log_node : int -> Ast.node -> unit
 val dbg_line : string -> unit
 val dbg_node : Ast.node -> unit
 
-(* Generate a fresh variable from a given prefix, e.g. "counter" -> "counter$1"  *)
+(* Generate a fresh variable from a given prefix, e.g. "foo" -> "foo$1"  *)
 val fresh_var : string -> string
 
+(* Generate a fresh constant from a given prefix, e.g. "foo" -> "foo$$1"  *)
+val fresh_const : string -> string
+
 (* Generate an Ast.loc tuple from Lexing data structures *)
 val loc_from_lexpos : Lexing.position -> Lexing.position -> Ast.loc