(** Constant propagation.

    The compiler sometimes generates variables whose names end in [$$]
    followed by digits (e.g. [foo$$1]), to make sure that expressions are
    only executed once. In many cases, this leads to overly complex
    constructions, for example when converting for-loops to while-loops. We
    use the knowledge of these variables being constant by propagating the
    constant values to their occurrences, and then apply arithmetic
    simplification to operators to reduce the size and complexity of the
    generated code.

    Note that this can only be applied to constants. For variables in
    general, some form of liveness analysis would be required (e.g. Static
    Single Assignment form). Expressions can only be propagated when they
    have no side effects, i.e. when they do not contain function calls. *)

open Types
open Util

(** [is_const_name name] is [true] when [name] consists of at least one
    character followed by [$$] and one or more digits, e.g. [foo$$1].
    NOTE(review): confirm this matches the naming scheme used where these
    variables are generated. *)
let is_const_name =
  (* Compile the pattern once rather than on every call. *)
  let pattern = Str.regexp "^.+\\$\\$[0-9]+$" in
  fun name -> Str.string_match pattern name 0

(** [is_const node] is [true] iff [node] is a [Const] literal. *)
let is_const = function Const _ -> true | _ -> false

(** [eval_monop (op, opnd, ann)] folds a unary operator applied to a constant
    operand into a single [Const] annotated with [ann]. Combinations that
    cannot be folded are rebuilt unchanged as [Monop (op, opnd, ann)]. *)
let eval_monop = function
  | (Not, Const (BoolVal value, _), ann) -> Const (BoolVal (not value), ann)
  | (Neg, Const (IntVal value, _), ann) -> Const (IntVal (-value), ann)
  | (Neg, Const (FloatVal value, _), ann) -> Const (FloatVal (-.value), ann)
  | (op, opnd, ann) -> Monop (op, opnd, ann)

(** [eval_binop (op, left, right, ann)] folds a binary operator applied to two
    constant operands of matching type into a single [Const] annotated with
    [ann]. Combinations that cannot be folded are rebuilt unchanged as
    [Binop (op, left, right, ann)].

    NOTE(review): integer folding uses OCaml's native [int] arithmetic; if the
    target language's integers are narrower, overflowing results may differ
    from run-time behavior — confirm. *)
let eval_binop = function
  (* Integer division/modulo by a constant zero is left for run time:
     folding it would raise [Division_by_zero] inside the compiler. *)
  | (((Div | Mod) as op), left, (Const (IntVal 0, _) as right), ann) ->
    Binop (op, left, right, ann)
  (* Arithmetic *)
  | (Add, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left + right), ann)
  | (Add, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left +. right), ann)
  | (Sub, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left - right), ann)
  | (Sub, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left -. right), ann)
  | (Mul, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left * right), ann)
  | (Mul, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left *. right), ann)
  | (Div, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left / right), ann)
  | (Div, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (FloatVal (left /. right), ann)
  | (Mod, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (IntVal (left mod right), ann)
  (* Relational. [<>] is structural inequality; the physical [!=] would
     compare float boxes rather than float values. *)
  | (Eq, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left = right), ann)
  | (Eq, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left = right), ann)
  | (Ne, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)
  | (Ne, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <> right), ann)
  | (Gt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left > right), ann)
  | (Gt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left > right), ann)
  | (Lt, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left < right), ann)
  | (Lt, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left < right), ann)
  | (Ge, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)
  | (Ge, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left >= right), ann)
  | (Le, Const (IntVal left, _), Const (IntVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)
  | (Le, Const (FloatVal left, _), Const (FloatVal right, _), ann) ->
    Const (BoolVal (left <= right), ann)
  (* Logical *)
  | (And, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left && right), ann)
  | (Or, Const (BoolVal left, _), Const (BoolVal right, _), ann) ->
    Const (BoolVal (left || right), ann)
  | (op, left, right, ann) -> Binop (op, left, right, ann)

(** [propagate consts node] rewrites [node] as follows:
    - an [Assign] / [VarLet] (with [None] as second component) to a name
      accepted by [is_const_name], whose propagated value is a [Const], is
      recorded in [consts] and replaced by [DummyNode];
    - a [Var] / [VarUse] (with [None] as second component) or [Dim] whose
      name is recorded in [consts] is replaced by the recorded constant;
    - [Monop], [Binop], [Cond] and [TypeCast] nodes have their operands
      propagated first and are folded when those operands are constants;
    - every other node is handed to [transform_children].
    NOTE(review): replacing uses relies on [transform_children] visiting an
    assignment before later uses of the same name — confirm. *)
let rec propagate consts node =
  let propagate = propagate consts in
  match node with
  (* Constant assignments are added to the constants table *)
  | Assign (name, None, value, ann) when is_const_name name ->
    let value = propagate value in
    if is_const value then (
      Hashtbl.replace consts name value;
      DummyNode)
    else Assign (name, None, value, ann)
  | VarLet (dec, None, value, ann) when is_const_name (nameof dec) ->
    let value = propagate value in
    if is_const value then (
      Hashtbl.replace consts (nameof dec) value;
      DummyNode)
    else VarLet (dec, None, value, ann)
  (* Variables that are in the constant table are replaced with their
     constant value *)
  | Var (name, None, _) when Hashtbl.mem consts name -> Hashtbl.find consts name
  | VarUse (dec, None, _) when Hashtbl.mem consts (nameof dec) ->
    Hashtbl.find consts (nameof dec)
  | Dim (name, _) when Hashtbl.mem consts name -> Hashtbl.find consts name
  (* Apply arithmetic simplification to constant operands *)
  | Monop (op, opnd, ann) ->
    let opnd = propagate opnd in
    if is_const opnd then eval_monop (op, opnd, ann)
    else Monop (op, opnd, ann)
  | Binop (op, left, right, ann) ->
    let left = propagate left in
    let right = propagate right in
    if is_const left && is_const right then eval_binop (op, left, right, ann)
    else Binop (op, left, right, ann)
  | Cond (cond, texp, fexp, ann) ->
    let cond = propagate cond in
    let texp = propagate texp in
    let fexp = propagate fexp in
    (match cond with
     | Const (BoolVal value, _) -> if value then texp else fexp
     | _ -> Cond (cond, texp, fexp, ann))
  | TypeCast (ctype, value, ann) ->
    let value = propagate value in
    (match (ctype, value) with
     | (Bool, Const (BoolVal value, _)) -> Const (BoolVal value, ann)
     (* Numeric-to-bool casts: zero is false, anything else is true
        (consistent with the bool-to-number cases below, true -> 1). *)
     | (Bool, Const (IntVal value, _)) -> Const (BoolVal (value <> 0), ann)
     | (Bool, Const (FloatVal value, _)) ->
       Const (BoolVal (value <> 0.0), ann)
     | (Int, Const (BoolVal value, _)) ->
       Const (IntVal (if value then 1 else 0), ann)
     | (Int, Const (IntVal value, _)) -> Const (IntVal value, ann)
     | (Int, Const (FloatVal value, _)) ->
       Const (IntVal (int_of_float value), ann)
     | (Float, Const (BoolVal value, _)) ->
       Const (FloatVal (if value then 1. else 0.), ann)
     | (Float, Const (IntVal value, _)) ->
       Const (FloatVal (float_of_int value), ann)
     | (Float, Const (FloatVal value, _)) -> Const (FloatVal value, ann)
     | _ -> TypeCast (ctype, value, ann))
  | _ -> transform_children propagate node

(** [prune_vardecs consts node] replaces with [DummyNode] every [VarDec]
    whose name is a key of [consts], i.e. the declarations of variables whose
    constant value was propagated by [propagate]. *)
let rec prune_vardecs consts = function
  | VarDec (_, name, _, _) when Hashtbl.mem consts name -> DummyNode
  | node -> transform_children (prune_vardecs consts) node

(** Compiler phase entry point. When [args.optimize] is set, runs
    [propagate] followed by [prune_vardecs] over the AST; otherwise the input
    is returned unchanged.
    @raise InvalidInput if the input is not an [Ast]. *)
let phase = function
  | Ast node as input ->
    if args.optimize then (
      log_line 1 "- Constant propagation";
      let consts = Hashtbl.create 32 in
      let node = propagate consts node in
      Ast (prune_vardecs consts node))
    else input
  | _ -> raise (InvalidInput "constant propagation")