(**
 * The compiler sometimes generates variables of the form __foo_1__, to make
 * sure that expressions are only executed once. In many cases, this leads to
 * over-complex constructions, for example when converting for-loops to
 * while-loops. We use the knowledge of these variables being constant by
 * propagating the constant values to their occurrences, and then apply
 * arithmetic simplification to operators to reduce the size and complexity of
 * the generated code. Note that this can only be applied to constants. For
 * variables in general, some form of liveness analysis would be required (e.g.
 * Static Single Assignment form). Expressions can only be propagated when they
 * have no side effects, i.e. when they do not contain function calls.
 *
 * Constant propagation is merged with some arithmetic simplification here,
 * specifically targeting optimization opportunities created by earlier
 * constant propagation. This is utilized, for example, in array index
 * calculation when array dimensions are constant.
 *)

open Types
open Util

(* A node may be propagated when it is a literal constant, or a reference to a
 * variable whose name is accepted by is_const_id (defined in Util; presumably
 * it recognises the compiler-generated names described in the header). *)
let is_const = function
  | Const _ -> true
  | VarUse (dec, None, _) -> is_const_id (nameof dec)
  | Var (name, _, _) -> is_const_id name
  | _ -> false

(* Only assignments to variables local to this module can be removed, others
 * must stay for consistency when used by other modules *)
let is_local_dec = function
  | VarDec _ -> true
  | GlobalDef (export, _, _, _, _) -> not export
  | _ -> false

(* Play-it-safe side effect analysis: only return true for variables and
 * constants, since these are targeted in arithmetic simplification (in
 * particular targeting array indices that can be simplified after array
 * dimension reduction) *)
let no_side_effect = function
  | Const _ | VarUse _ | Var _ -> true
  | _ -> false

(* Constant folding + arithmetic simplification.
 *
 * Takes a single node whose operands have already been simplified and returns
 * either a folded/simplified replacement or the node itself when no rule
 * applies. Integer division and modulo by a constant zero are deliberately
 * NOT folded: Int32.div and Int32.rem raise Division_by_zero, which would
 * crash the compiler, so the expression is left untouched instead. *)
let eval node =
  (* Redefine integer operators within this function since they are only used
   * on IntVal values, which have type int32 *)
  let ( + ) = Int32.add in
  let ( - ) = Int32.sub in
  let ( / ) = Int32.div in
  let ( * ) = Int32.mul in
  let ( mod ) = Int32.rem in

  match node with
  (* Binop - arithmetic *)
  | Binop (Add, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left + right), ann)
  | Binop (Add, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left +. right), ann)

  | Binop (Sub, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left - right), ann)
  | Binop (Sub, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left -. right), ann)

  | Binop (Mul, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left * right), ann)
  | Binop (Mul, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left *. right), ann)

  | Binop (Div, Const (IntVal left, _), Const (IntVal right, _), ann)
    when right <> 0l ->
    Const (IntVal (left / right), ann)
  | Binop (Div, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left /. right), ann)

  (* Same zero guard as integer division: Int32.rem raises on a zero divisor *)
  | Binop (Mod, Const (IntVal left, _), Const (IntVal right, _), ann)
    when right <> 0l ->
    Const (IntVal (left mod right), ann)

  (* Binop - relational *)
  | Binop (Eq, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left = right), ann)
  | Binop (Eq, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left = right), ann)

  | Binop (Ne, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)
  | Binop (Ne, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)

  | Binop (Gt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left > right), ann)
  | Binop (Gt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left > right), ann)

  | Binop (Lt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left < right), ann)
  | Binop (Lt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left < right), ann)

  | Binop (Ge, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)
  | Binop (Ge, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)

  | Binop (Le, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)
  | Binop (Le, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)

  (* Binop - logical *)
  | Binop (And, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left && right), ann)
  | Binop (Or, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left || right), ann)

  (* Unary operations *)
  | Monop (Not, Const (BoolVal value, _), ann) ->
    Const (BoolVal (not value), ann)
  | Monop (Neg, Const (IntVal value, _), ann) ->
    Const (IntVal (Int32.neg value), ann)
  | Monop (Neg, Const (FloatVal value, _), ann) ->
    Const (FloatVal (-.value), ann)

  (* 0 * a --> 0, only when dropping a cannot discard a side effect *)
  | Binop (Mul, Const (IntVal 0l, _), other, ann)
  | Binop (Mul, other, Const (IntVal 0l, _), ann)
    when no_side_effect other ->
    Const (IntVal 0l, ann)

  (* 0 + a --> a *)
  | Binop (Add, Const (IntVal 0l, _), other, _)
  | Binop (Add, other, Const (IntVal 0l, _), _) ->
    other

  (* 1 * a --> a *)
  | Binop (Mul, Const (IntVal 1l, _), other, _)
  | Binop (Mul, other, Const (IntVal 1l, _), _) ->
    other

  (* true|false ? texp : fexp --> texp|fexp *)
  | Cond (Const (BoolVal value, _), texp, fexp, _) ->
    if value then texp else fexp

  (* No rule applies: leave the node as it is *)
  | node -> node

(* Walk the tree, recording constant assignments in the consts table (name ->
 * constant node) and replacing later uses of those names by the recorded
 * value. Operators, conditionals and type casts are simplified with eval
 * after their operands have been propagated. Returns the rewritten node;
 * the consts table is updated in place. *)
let rec propagate consts node =
  let propagate = propagate consts in
  match node with

  (* Constant assignments are added to the constants table and the assignment
   * itself is replaced by DummyNode *)
  | Assign (name, None, value, ann) when is_const_id name ->
    let value = propagate value in
    if is_const value then begin
      Hashtbl.replace consts name value;
      DummyNode
    end else
      Assign (name, None, value, ann)

  (* Same for declaring assignments, except that the assignment is only
   * dropped when the declaration is local to this module *)
  | VarLet (dec, None, value, ann) when is_const_id (nameof dec) ->
    let value = propagate value in
    if is_const value then begin
      Hashtbl.replace consts (nameof dec) value;
      if is_local_dec dec then DummyNode else VarLet (dec, None, value, ann)
    end else
      VarLet (dec, None, value, ann)

  (* Variables that are in the constant table are replaced with their constant
   * value *)
  | Var (name, None, _) when Hashtbl.mem consts name ->
    Hashtbl.find consts name
  | VarUse (dec, None, _) when Hashtbl.mem consts (nameof dec) ->
    Hashtbl.find consts (nameof dec)
  | Dim (name, _) when Hashtbl.mem consts name ->
    Hashtbl.find consts name

  (* Apply arithmetic simplification to constant operands *)
  | Monop (op, opnd, ann) ->
    eval (Monop (op, propagate opnd, ann))
  | Binop (op, left, right, ann) ->
    eval (Binop (op, propagate left, propagate right, ann))
  | Cond (cond, texp, fexp, ann) ->
    eval (Cond (propagate cond, propagate texp, propagate fexp, ann))

  (* Fold type casts applied to constants *)
  | TypeCast (ctype, value, ann) ->
    let c v = Const (v, ann) in
    begin match ctype, propagate value with
      (* Cast to the type the constant already has: drop the cast *)
      | Bool, (Const (BoolVal _, _) as v)
      | Int, (Const (IntVal _, _) as v)
      | Float, (Const (FloatVal _, _) as v) -> v

      (* Numeric to bool: zero is false, anything else is true. Structural
       * inequality is required here, since int32 and float values are boxed
       * and physical inequality would compare addresses.
       * NOTE(review): assumes C-like truthiness, mirroring the bool-to-number
       * casts below (true = 1, false = 0) - confirm against the language
       * specification. *)
      | Bool, Const (IntVal v, _) -> c (BoolVal (v <> 0l))
      | Bool, Const (FloatVal v, _) -> c (BoolVal (v <> 0.0))

      | Int, Const (BoolVal v, _) -> c (IntVal (if v then 1l else 0l))
      | Int, Const (FloatVal v, _) -> c (IntVal (Int32.of_float v))

      | Float, Const (BoolVal v, _) -> c (FloatVal (if v then 1. else 0.))
      | Float, Const (IntVal v, _) -> c (FloatVal (Int32.to_float v))

      (* Operand is not a foldable constant: keep the cast around the
       * propagated operand *)
      | _, v -> TypeCast (ctype, v, ann)
    end

  | _ -> traverse_unit propagate node

(* Replace the declaration of every variable that ended up in the constants
 * table by DummyNode; all other nodes are traversed recursively. *)
let rec prune_vardecs consts = function
  | VarDec (_, name, _, _) when Hashtbl.mem consts name -> DummyNode
  | node -> traverse_unit (prune_vardecs consts) node

(* Phase entry point: propagate constants through the AST using a fresh
 * constants table, then prune the declarations of the propagated variables.
 * Raises InvalidInput for any input other than an Ast. *)
let phase = function
  | Ast node ->
    let consts = Hashtbl.create 32 in
    let node = propagate consts node in
    Ast (prune_vardecs consts node)
  | _ -> raise InvalidInput