constant_propagation.ml 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
(**
 * The compiler sometimes generates variables of the form foo$$1 to make sure
 * that expressions are only executed once. In many cases, this leads to
 * over-complex constructions, for example when converting for-loops to
 * while-loops. We exploit the knowledge that these variables are constant by
 * propagating their constant values to their occurrences, and then apply
 * arithmetic simplification to operators to reduce the size and complexity of
 * the generated code. Note that this can only be applied to constants. For
 * variables in general, some form of data-flow analysis would be required
 * (e.g. reaching definitions, or conversion to Static Single Assignment form).
 * Expressions can only be propagated when they have no side effects, i.e. when
 * they do not contain function calls.
 *)
  13. open Types
  14. open Util
  (* Pattern for compiler-generated constant names: a non-empty prefix, then
   * "$$", then one or more digits (e.g. "foo$$1"). Compiled once when the
   * module is initialised instead of on every call to [is_const_name]. *)
  let const_name_regexp = Str.regexp "^.+\\$\\$[0-9]+$"

  (* [is_const_name name] is true when [name] has the shape of one of the
   * compiler-generated temporaries that this pass treats as constants. *)
  let is_const_name name =
    Str.string_match const_name_regexp name 0
  (* True exactly when [node] is a literal Const node, whatever its value. *)
  let is_const node =
    match node with
    | Const _ -> true
    | _ -> false
  (* Fold a unary operator applied to a constant operand: logical negation of
   * a bool constant, arithmetic negation of an int or float constant. The
   * result carries [ann]. Any other operator/operand combination is rebuilt
   * as a Monop node unchanged. *)
  let eval_monop (op, opnd, ann) =
    match (op, opnd) with
    | (Not, Const (BoolVal b, _)) -> Const (BoolVal (not b), ann)
    | (Neg, Const (IntVal i, _)) -> Const (IntVal (~- i), ann)
    | (Neg, Const (FloatVal f, _)) -> Const (FloatVal (~-. f), ann)
    | _ -> Monop (op, opnd, ann)
  (* Fold a binary operator whose operands are both constants of the same
   * kind (int/int, float/float or bool/bool) into a single Const carrying
   * [ann]. Anything that cannot be folded is rebuilt as a Binop unchanged.
   *
   * Division and modulo by a constant zero are deliberately NOT folded:
   * - for ints, OCaml's [/] and [mod] would raise Division_by_zero inside the
   *   compiler itself;
   * - for floats, folding would produce an infinite/NaN literal, and this
   *   pass cannot vouch for how later phases emit one.
   * Leaving such nodes in place preserves the program's run-time behaviour.
   *
   * NOTE(review): int folding uses OCaml's native int width, which may differ
   * from the target language's int width (overflow would not wrap the same
   * way) — confirm against the target. *)
  let eval_binop (op, left, right, ann) =
    let folded =
      match (left, right) with
      | (Const (IntVal l, _), Const (IntVal r, _)) ->
        (match op with
         | Add -> Some (IntVal (l + r))
         | Sub -> Some (IntVal (l - r))
         | Mul -> Some (IntVal (l * r))
         | Div when r <> 0 -> Some (IntVal (l / r))
         | Mod when r <> 0 -> Some (IntVal (l mod r))
         | Eq -> Some (BoolVal (l = r))
         | Ne -> Some (BoolVal (l <> r))
         | Gt -> Some (BoolVal (l > r))
         | Lt -> Some (BoolVal (l < r))
         | Ge -> Some (BoolVal (l >= r))
         | Le -> Some (BoolVal (l <= r))
         | _ -> None)
      | (Const (FloatVal l, _), Const (FloatVal r, _)) ->
        (match op with
         | Add -> Some (FloatVal (l +. r))
         | Sub -> Some (FloatVal (l -. r))
         | Mul -> Some (FloatVal (l *. r))
         | Div when r <> 0.0 -> Some (FloatVal (l /. r))
         | Eq -> Some (BoolVal (l = r))
         | Ne -> Some (BoolVal (l <> r))
         | Gt -> Some (BoolVal (l > r))
         | Lt -> Some (BoolVal (l < r))
         | Ge -> Some (BoolVal (l >= r))
         | Le -> Some (BoolVal (l <= r))
         | _ -> None)
      | (Const (BoolVal l, _), Const (BoolVal r, _)) ->
        (match op with
         | And -> Some (BoolVal (l && r))
         | Or -> Some (BoolVal (l || r))
         (* Only reached if earlier phases allow == / != on booleans; the
          * fold is then exact. *)
         | Eq -> Some (BoolVal (l = r))
         | Ne -> Some (BoolVal (l <> r))
         | _ -> None)
      | _ -> None
    in
    match folded with
    | Some value -> Const (value, ann)
    | None -> Binop (op, left, right, ann)
  (* Fold a type cast whose operand is a constant into a constant of the
   * target type, carrying [ann]. Numeric-to-bool casts map zero to false and
   * every other value to true. Bool-to-number casts map true/false to 1/0.
   * Float-to-int truncates via [int_of_float]. A cast of a non-constant
   * operand is rebuilt as a TypeCast unchanged. *)
  let eval_typecast = function
    | (Bool, Const (BoolVal value, _), ann) -> Const (BoolVal value, ann)
    | (Bool, Const (IntVal value, _), ann) -> Const (BoolVal (value <> 0), ann)
    (* Structural [<>], not physical [!=]: boxed floats are never physically
     * equal, so [!=] would make every float cast to true. *)
    | (Bool, Const (FloatVal value, _), ann) -> Const (BoolVal (value <> 0.0), ann)
    | (Int, Const (BoolVal value, _), ann) ->
      Const (IntVal (if value then 1 else 0), ann)
    | (Int, Const (IntVal value, _), ann) -> Const (IntVal value, ann)
    | (Int, Const (FloatVal value, _), ann) -> Const (IntVal (int_of_float value), ann)
    | (Float, Const (BoolVal value, _), ann) ->
      Const (FloatVal (if value then 1. else 0.), ann)
    | (Float, Const (IntVal value, _), ann) -> Const (FloatVal (float_of_int value), ann)
    | (Float, Const (FloatVal value, _), ann) -> Const (FloatVal value, ann)
    | (ctype, value, ann) -> TypeCast (ctype, value, ann)

  (* Walk [node] and fold constants, using [consts] (a Hashtbl from variable
   * name to its Const node) as the table of known constants:
   * - an Assign/VarLet with a [None] second field (presumably "no array
   *   index" — confirm against Types) to a compiler-generated constant name
   *   (see [is_const_name]) whose value folds to a Const is recorded in
   *   [consts] and replaced by DummyNode; otherwise it is kept with its
   *   folded value;
   * - Var, VarUse and Dim nodes naming a recorded constant are replaced by
   *   the recorded Const;
   * - Monop, Binop and TypeCast nodes are evaluated when their operands fold
   *   to constants, and a Cond whose condition folds to a bool constant is
   *   replaced by the selected branch;
   * - every other node is rebuilt through [transform_children].
   * [consts] is mutated in place. *)
  let rec propagate consts node =
    let propagate = propagate consts in
    match node with
    (* Constant assignments are added to the constants table *)
    | Assign (name, None, value, ann) when is_const_name name ->
      let value = propagate value in
      if is_const value then begin
        Hashtbl.add consts name value;
        DummyNode
      end else
        Assign (name, None, value, ann)
    | VarLet (dec, None, value, ann) when is_const_name (nameof dec) ->
      let value = propagate value in
      if is_const value then begin
        Hashtbl.add consts (nameof dec) value;
        DummyNode
      end else
        VarLet (dec, None, value, ann)
    (* Variables that are in the constants table are replaced with their
     * constant value *)
    | Var (name, None, _) when Hashtbl.mem consts name ->
      Hashtbl.find consts name
    | VarUse (dec, None, _) when Hashtbl.mem consts (nameof dec) ->
      Hashtbl.find consts (nameof dec)
    | Dim (name, _) when Hashtbl.mem consts name ->
      Hashtbl.find consts name
    (* Apply arithmetic simplification to constant operands *)
    | Monop (op, opnd, ann) ->
      let opnd = propagate opnd in
      if is_const opnd
      then eval_monop (op, opnd, ann)
      else Monop (op, opnd, ann)
    | Binop (op, left, right, ann) ->
      let left = propagate left in
      let right = propagate right in
      if is_const left && is_const right
      then eval_binop (op, left, right, ann)
      else Binop (op, left, right, ann)
    | Cond (cond, texp, fexp, ann) ->
      let cond = propagate cond in
      let texp = propagate texp in
      let fexp = propagate fexp in
      (match cond with
       | Const (BoolVal value, _) -> if value then texp else fexp
       | _ -> Cond (cond, texp, fexp, ann))
    | TypeCast (ctype, value, ann) ->
      eval_typecast (ctype, propagate value, ann)
    | _ -> transform_children propagate node
  (* Remove declarations of folded constants: a VarDec whose name (second
   * field) is a key of [consts] becomes DummyNode. Every other node is
   * rebuilt with this pruning applied to its children via
   * [transform_children]. *)
  let rec prune_vardecs consts node =
    let recurse = prune_vardecs consts in
    match node with
    | VarDec (_, declared, _, _) when Hashtbl.mem consts declared -> DummyNode
    | _ -> transform_children recurse node
  (* Entry point of the constant-propagation phase. It creates a fresh
   * constants table, folds constants throughout the tree with [propagate],
   * then drops the declarations of the variables that were folded away.
   * Raises [InvalidInput "constant propagation"] for any input that is not
   * an [Ast]. *)
  let phase input =
    match input with
    | Ast root ->
      let consts = Hashtbl.create 32 in
      Ast (root |> propagate consts |> prune_vardecs consts)
    | _ -> raise (InvalidInput "constant propagation")