| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
(**
 * The compiler sometimes generates variables of the form __foo_1__ to make
 * sure that expressions are only executed once. In many cases, this leads to
 * over-complex constructions, for example when converting for-loops to
 * while-loops. We exploit the fact that these variables are constant by
 * propagating their constant values to their occurrences, and then apply
 * arithmetic simplification to operators to reduce the size and complexity of
 * the generated code. Note that this can only be applied to constants. For
 * variables in general, some form of liveness analysis would be required (e.g.
 * Static Single Assignment form). Expressions can only be propagated when they
 * have no side effects, i.e. when they do not contain function calls.
 *
 * Constant propagation is merged with some arithmetic simplification here,
 * specifically targeting optimization opportunities created by earlier
 * constant propagation. This is used, for example, in array index
 * calculation when array dimensions are constant.
 *)
- open Types
- open Util
(* [is_const node] holds when [node] may be recorded as the value of a
 * constant name and substituted at that name's later use sites. Accepted
 * nodes are:
 * - any literal ([Const]);
 * - an unindexed reference ([Var]/[VarUse] with no indices) to a name for
 *   which [is_const_id] holds.
 * Indexed reads such as [__a__[i]] are rejected: their value depends on the
 * array contents and the index at the point of evaluation, so copying them
 * to later use sites would not be safe. *)
let is_const = function
  | Const _ -> true
  | VarUse (dec, None, _) -> is_const_id (nameof dec)
  | Var (name, None, _) -> is_const_id name
  | _ -> false
(* Decide whether an assignment to the declaration [dec] may be removed.
 * Only variables local to this module qualify; the others must stay so that
 * other modules using them remain consistent. A [GlobalDef] is local unless
 * it is exported, a [VarDec] is always local, and every other node is
 * treated as non-local. *)
let is_local_dec dec =
  match dec with
  | GlobalDef (exported, _, _, _, _) -> not exported
  | VarDec _ -> true
  | _ -> false
(* Deliberately conservative side-effect analysis. Only literals and variable
 * references are reported as side-effect free, because these are the
 * operands arithmetic simplification targets (in particular array indices
 * that become simplifiable after array dimension reduction). Every other
 * node is assumed to have side effects. *)
let no_side_effect node =
  match node with
  | Var _ | VarUse _ | Const _ -> true
  | _ -> false
(* Constant folding + arithmetic simplification.
 *
 * [eval node] performs a single, non-recursive rewrite step on [node];
 * [propagate] calls it after the operands have already been propagated.
 * - Arithmetic, relational and logical operators whose operands are literals
 *   of the matching kind (for the combinations listed below) are folded into
 *   a literal result.
 * - [0 * a] becomes [0] when [a] passes [no_side_effect]; [0 + a] and [1 * a]
 *   become [a] (integer literals only).
 * - A conditional expression with a literal condition is replaced by the
 *   selected branch.
 * Any other node is returned unchanged.
 *
 * Integer folding uses [Int32], whose arithmetic is taken modulo 2^32.
 * Integer division and modulo by a literal zero are deliberately left
 * unfolded: [Int32.div] and [Int32.rem] raise [Division_by_zero], which would
 * otherwise abort compilation on a program such as [x % 0]. *)
let eval node =
  (* Redefine integer operators within this function since they are only used
   * on IntVal values, which have type int32 *)
  let ( + ) = Int32.add in
  let ( - ) = Int32.sub in
  let ( / ) = Int32.div in
  let ( * ) = Int32.mul in
  let ( mod ) = Int32.rem in
  match node with
  (* Binop - arithmetic *)
  | Binop (Add, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left + right), ann)
  | Binop (Add, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left +. right), ann)
  | Binop (Sub, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left - right), ann)
  | Binop (Sub, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left -. right), ann)
  | Binop (Mul, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left * right), ann)
  | Binop (Mul, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left *. right), ann)
  | Binop (Div, Const (IntVal left, _), Const (IntVal right, _), ann)
    when right <> 0l ->
    Const (IntVal (left / right), ann)
  | Binop (Div, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left /. right), ann)
  (* Same zero guard as Div: Int32.rem raises Division_by_zero on 0l *)
  | Binop (Mod, Const (IntVal left, _), Const (IntVal right, _), ann)
    when right <> 0l ->
    Const (IntVal (left mod right), ann)
  (* Binop - relational *)
  | Binop (Eq, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left = right), ann)
  | Binop (Eq, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left = right), ann)
  | Binop (Ne, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)
  | Binop (Ne, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)
  | Binop (Gt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left > right), ann)
  | Binop (Gt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left > right), ann)
  | Binop (Lt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left < right), ann)
  | Binop (Lt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left < right), ann)
  | Binop (Ge, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)
  | Binop (Ge, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)
  | Binop (Le, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)
  | Binop (Le, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)
  (* Binop - logical *)
  | Binop (And, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left && right), ann)
  | Binop (Or, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left || right), ann)
  (* Unary operations *)
  | Monop (Not, Const (BoolVal value, _), ann) ->
    Const (BoolVal (not value), ann)
  | Monop (Neg, Const (IntVal value, _), ann) ->
    Const (IntVal (Int32.neg value), ann)
  | Monop (Neg, Const (FloatVal value, _), ann) ->
    Const (FloatVal (-. value), ann)
  (* 0 * a --> 0, only when dropping [a] cannot drop a side effect *)
  | Binop (Mul, Const (IntVal 0l, _), other, ann)
  | Binop (Mul, other, Const (IntVal 0l, _), ann) when no_side_effect other ->
    Const (IntVal 0l, ann)
  (* 0 + a --> a *)
  | Binop (Add, Const (IntVal 0l, _), other, _)
  | Binop (Add, other, Const (IntVal 0l, _), _) ->
    other
  (* 1 * a --> a *)
  | Binop (Mul, Const (IntVal 1l, _), other, _)
  | Binop (Mul, other, Const (IntVal 1l, _), _) ->
    other
  (* true|false ? texp : fexp --> texp|fexp *)
  | Cond (Const (BoolVal value, _), texp, fexp, _) ->
    if value then texp else fexp
  | node -> node
(* Fold a cast of the already-propagated operand [value] to [ctype].
 * Casts of literals are evaluated at compile time. Numeric-to-bool casts
 * compare against zero (non-zero --> true), and bool-to-numeric casts map
 * true/false to 1/0. Any other operand is rebuilt as a [TypeCast] node
 * carrying [ann]. *)
let eval_cast ctype value ann =
  let c v = Const (v, ann) in
  match ctype, value with
  | Bool, (Const (BoolVal _, _) as v)
  | Int, (Const (IntVal _, _) as v)
  | Float, (Const (FloatVal _, _) as v) -> v
  (* Structural [<>] against zero: the old physical [!=] against one folded
   * even (bool)0 to true *)
  | Bool, Const (IntVal v, _) -> c (BoolVal (v <> 0l))
  | Bool, Const (FloatVal v, _) -> c (BoolVal (v <> 0.))
  | Int, Const (BoolVal v, _) -> c (IntVal (if v then 1l else 0l))
  | Int, Const (FloatVal v, _) -> c (IntVal (Int32.of_float v))
  | Float, Const (BoolVal v, _) -> c (FloatVal (if v then 1. else 0.))
  | Float, Const (IntVal v, _) -> c (FloatVal (Int32.to_float v))
  | _, v -> TypeCast (ctype, v, ann)

(* [propagate consts node] recursively rewrites [node]. It substitutes the
 * values of constant names (names for which [is_const_id] holds) and
 * simplifies the resulting expressions with [eval] / [eval_cast].
 *
 * [consts] maps constant names to their propagated value. It is filled in as
 * assignments are encountered, so a name is only substituted at uses the
 * traversal visits after the assignment.
 * - An unindexed [Assign]/[VarLet] to a constant name whose propagated value
 *   satisfies [is_const] is recorded in [consts] and replaced by a
 *   [DummyNode]. The exception is a [VarLet] whose declaration is not
 *   [is_local_dec]; it is kept (with the propagated value) so that other
 *   modules still see it.
 * - Unindexed [Var]/[VarUse] references and [Dim] nodes whose name is in
 *   [consts] are replaced by the recorded value.
 * - [Monop], [Binop], [Cond] and [TypeCast] are simplified after their
 *   operands have been propagated.
 * - All other nodes are traversed with [traverse_unit]. *)
let rec propagate consts node =
  let propagate = propagate consts in
  match node with
  (* Constant assignments are added to the constants table *)
  | Assign (name, None, value, ann) when is_const_id name ->
    let value = propagate value in
    if is_const value then begin
      Hashtbl.replace consts name value;
      DummyNode
    end else
      Assign (name, None, value, ann)
  | VarLet (dec, None, value, ann) when is_const_id (nameof dec) ->
    let value = propagate value in
    if is_const value then begin
      Hashtbl.replace consts (nameof dec) value;
      (* Non-local declarations keep their initialisation for other modules *)
      if is_local_dec dec then DummyNode else VarLet (dec, None, value, ann)
    end else
      VarLet (dec, None, value, ann)
  (* Variables in the constants table are replaced with their constant value *)
  | Var (name, None, _) when Hashtbl.mem consts name ->
    Hashtbl.find consts name
  | VarUse (dec, None, _) when Hashtbl.mem consts (nameof dec) ->
    Hashtbl.find consts (nameof dec)
  | Dim (name, _) when Hashtbl.mem consts name ->
    Hashtbl.find consts name
  (* Apply arithmetic simplification once the operands are propagated *)
  | Monop (op, opnd, ann) ->
    eval (Monop (op, propagate opnd, ann))
  | Binop (op, left, right, ann) ->
    eval (Binop (op, propagate left, propagate right, ann))
  | Cond (cond, texp, fexp, ann) ->
    eval (Cond (propagate cond, propagate texp, propagate fexp, ann))
  | TypeCast (ctype, value, ann) ->
    eval_cast ctype (propagate value) ann
  | _ -> traverse_unit propagate node
(* Replace with a [DummyNode] every [VarDec] whose name has a recorded value
 * in [consts] (filled in by [propagate]). Presumably such declarations are
 * dead once their uses have been substituted. All other nodes are traversed
 * recursively with [traverse_unit]. *)
let rec prune_vardecs consts node =
  match node with
  | VarDec (_, name, _, _) when Hashtbl.mem consts name -> DummyNode
  | _ -> traverse_unit (prune_vardecs consts) node
(* Run constant propagation and simplification over [node], then prune the
 * declarations of the constants that were recorded. A fresh constants table
 * is used for every call. *)
let propagate_consts node =
  let consts = Hashtbl.create 32 in
  node |> propagate consts |> prune_vardecs consts
(* Compiler phase entry point: an [Ast] input is rewritten with
 * [propagate_consts]; any other input raises [InvalidInput]. *)
let phase input =
  match input with
  | Ast node -> Ast (propagate_consts node)
  | _ -> raise InvalidInput
|