constant_propagation.ml 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  (**
   * The compiler sometimes generates variables of the form foo$$1 to make sure
   * that expressions are only executed once. In many cases, this leads to
   * over-complex constructions, for example when converting for-loops to
   * while-loops. We use the knowledge that these variables are constant by
   * propagating their constant values to their occurrences, and then apply
   * arithmetic simplification to operators to reduce the size and complexity of
   * the generated code. Note that this can only be applied to constants. For
   * variables in general, some form of data-flow analysis would be required
   * (e.g. reaching definitions, or conversion to Static Single Assignment form).
   * Expressions can only be propagated when they have no side effects, i.e. when
   * they do not contain function calls; this pass therefore only propagates
   * values that evaluate to literal boolean, integer or float constants.
   *)
  13. open Ast
  14. open Util
  (* Compiled once at module initialisation instead of on every call.
   * Matches a non-empty prefix, a literal "$$" and one or more decimal digits,
   * e.g. "foo$$1". *)
  let const_name_regexp = Str.regexp "^.+\\$\\$[0-9]+$"

  (* [is_const_name name] is true when [name] has the form of a
   * compiler-generated constant variable (see const_name_regexp). *)
  let is_const_name name =
    Str.string_match const_name_regexp name 0
  (* [is_const node] holds exactly when [node] is a boolean, integer or
   * floating-point literal; every other node kind yields false. *)
  let is_const node =
    match node with
    | BoolConst _ -> true
    | IntConst _ -> true
    | FloatConst _ -> true
    | _ -> false
  (* Fold a unary operator applied to a literal constant.
   * Takes an (operator, operand, annotation) tuple. Logical negation of a
   * bool literal and arithmetic negation of an int or float literal are
   * evaluated, and the result carries the given annotation; any other
   * combination is rebuilt as the original Monop node. *)
  let eval_monop (op, opnd, ann) =
    match op, opnd with
    | Not, BoolConst (b, _) -> BoolConst (not b, ann)
    | Neg, IntConst (i, _) -> IntConst (- i, ann)
    | Neg, FloatConst (f, _) -> FloatConst (-. f, ann)
    | _ -> Monop (op, opnd, ann)
  (* Fold a binary operator whose operands are both literal constants.
   * Takes an (operator, left, right, annotation) tuple and returns the folded
   * constant, carrying the operator's annotation. Operator/operand
   * combinations that are not folded -- including integer division or modulo
   * by zero -- are rebuilt unchanged as a Binop node.
   * NOTE(review): integer folding uses OCaml's native int arithmetic; if the
   * source language's int type is narrower, results that overflow it will
   * differ from the run-time result -- confirm. *)
  let eval_binop = function
    (* Arithmetic *)
    | (Add, IntConst (left, _), IntConst (right, _), ann) ->
      IntConst (left + right, ann)
    | (Add, FloatConst (left, _), FloatConst (right, _), ann) ->
      FloatConst (left +. right, ann)
    | (Sub, IntConst (left, _), IntConst (right, _), ann) ->
      IntConst (left - right, ann)
    | (Sub, FloatConst (left, _), FloatConst (right, _), ann) ->
      FloatConst (left -. right, ann)
    | (Mul, IntConst (left, _), IntConst (right, _), ann) ->
      IntConst (left * right, ann)
    | (Mul, FloatConst (left, _), FloatConst (right, _), ann) ->
      FloatConst (left *. right, ann)
    (* Folding an integer division or modulo by zero would raise
     * Division_by_zero inside the compiler itself; leave such expressions
     * unevaluated so they are handled at run time instead. *)
    | (Div, IntConst (left, _), IntConst (right, _), ann) when right <> 0 ->
      IntConst (left / right, ann)
    | (Div, FloatConst (left, _), FloatConst (right, _), ann) ->
      FloatConst (left /. right, ann)
    | (Mod, IntConst (left, _), IntConst (right, _), ann) when right <> 0 ->
      IntConst (left mod right, ann)
    (* Relational. Equality uses structural (=)/(<>); the physical operator
     * (!=) compares identity rather than value and is unreliable on boxed
     * floats. *)
    | (Eq, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left = right, ann)
    | (Eq, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left = right, ann)
    | (Eq, BoolConst (left, _), BoolConst (right, _), ann) ->
      BoolConst (left = right, ann)
    | (Ne, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left <> right, ann)
    | (Ne, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left <> right, ann)
    | (Ne, BoolConst (left, _), BoolConst (right, _), ann) ->
      BoolConst (left <> right, ann)
    | (Gt, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left > right, ann)
    | (Gt, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left > right, ann)
    | (Lt, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left < right, ann)
    | (Lt, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left < right, ann)
    | (Ge, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left >= right, ann)
    | (Ge, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left >= right, ann)
    | (Le, IntConst (left, _), IntConst (right, _), ann) ->
      BoolConst (left <= right, ann)
    | (Le, FloatConst (left, _), FloatConst (right, _), ann) ->
      BoolConst (left <= right, ann)
    (* Logical: both operands are already literals, so short-circuiting does
     * not matter here. *)
    | (And, BoolConst (left, _), BoolConst (right, _), ann) ->
      BoolConst (left && right, ann)
    | (Or, BoolConst (left, _), BoolConst (right, _), ann) ->
      BoolConst (left || right, ann)
    | (op, left, right, ann) -> Binop (op, left, right, ann)
  (* Fold a type cast applied to a literal constant.
   * Takes a (target type, operand, annotation) tuple. A number cast to bool
   * is true exactly when it is non-zero, so that it round-trips with the
   * bool-to-number casts below (true -> 1, false -> 0). Casts whose operand
   * is not a literal are rebuilt unchanged as a TypeCast node. *)
  let eval_typecast = function
    | (Bool, BoolConst (value, _), ann) -> BoolConst (value, ann)
    | (Bool, IntConst (value, _), ann) -> BoolConst (value <> 0, ann)
    | (Bool, FloatConst (value, _), ann) -> BoolConst (value <> 0.0, ann)
    | (Int, BoolConst (value, _), ann) -> IntConst ((if value then 1 else 0), ann)
    | (Int, IntConst (value, _), ann) -> IntConst (value, ann)
    | (Int, FloatConst (value, _), ann) -> IntConst (int_of_float value, ann)
    | (Float, BoolConst (value, _), ann) -> FloatConst ((if value then 1. else 0.), ann)
    | (Float, IntConst (value, _), ann) -> FloatConst (float_of_int value, ann)
    | (Float, FloatConst (value, _), ann) -> FloatConst (value, ann)
    | (ctype, value, ann) -> TypeCast (ctype, value, ann)

  (* Propagate constants through [node] and simplify the result.
   * - An Assign/VarLet to a compiler-generated constant name (is_const_name)
   *   whose value folds to a literal is recorded in the [consts] hash table
   *   (mutated in place) and replaced by DummyNode.
   * - Var, VarUse and Dim nodes naming a recorded constant are replaced by
   *   the recorded literal.
   * - Monop, Binop, Cond and TypeCast nodes whose operands have become
   *   literals are folded.
   * - Every other node is traversed recursively via transform_children.
   * Children are visited in source order, so a constant must be recorded
   * before its uses are reached. *)
  let rec propagate consts node =
    let propagate = propagate consts in
    match node with
    (* Constant assignments are added to the constants table *)
    | Assign (name, None, value, ann) when is_const_name name ->
      let value = propagate value in
      if is_const value then begin
        Hashtbl.add consts name value;
        DummyNode
      end else
        Assign (name, None, value, ann)
    | VarLet (dec, None, value, ann) when is_const_name (nameof dec) ->
      let value = propagate value in
      if is_const value then begin
        Hashtbl.add consts (nameof dec) value;
        DummyNode
      end else
        VarLet (dec, None, value, ann)
    (* Variables that are in the constant table are replaced with their
     * constant value *)
    | Var (name, None, _) when Hashtbl.mem consts name ->
      Hashtbl.find consts name
    | VarUse (dec, None, _) when Hashtbl.mem consts (nameof dec) ->
      Hashtbl.find consts (nameof dec)
    | Dim (name, _) when Hashtbl.mem consts name ->
      Hashtbl.find consts name
    (* Apply arithmetic simplification to constant operands *)
    | Monop (op, opnd, ann) ->
      let opnd = propagate opnd in
      if is_const opnd
      then eval_monop (op, opnd, ann)
      else Monop (op, opnd, ann)
    | Binop (op, left, right, ann) ->
      let left = propagate left in
      let right = propagate right in
      if is_const left && is_const right
      then eval_binop (op, left, right, ann)
      else Binop (op, left, right, ann)
    (* A conditional with a literal condition reduces to the selected branch *)
    | Cond (cond, texp, fexp, ann) ->
      let cond = propagate cond in
      let texp = propagate texp in
      let fexp = propagate fexp in
      (match cond with
       | BoolConst (value, _) -> if value then texp else fexp
       | _ -> Cond (cond, texp, fexp, ann))
    | TypeCast (ctype, value, ann) ->
      eval_typecast (ctype, propagate value, ann)
    | _ -> transform_children propagate node
  (* Replace every VarDec whose variable name is a key of [consts] with
   * DummyNode. All other nodes are rebuilt by recursing into their children
   * with transform_children. *)
  let rec prune_vardecs consts node =
    match node with
    | VarDec (_, name, _, _) when Hashtbl.mem consts name -> DummyNode
    | other -> transform_children (prune_vardecs consts) other
  (* Entry point of the constant-propagation compiler phase.
   * For an [Ast node] input, propagates and folds constants using a fresh
   * constants table, then removes the declarations of the constants that
   * were recorded, and returns the resulting Ast.
   * Raises [InvalidInput "constant propagation"] for any other input.
   * The function is not recursive, so it is bound without [rec]. *)
  let phase input =
    log_line 2 "- Constant propagation";
    match input with
    | Ast node ->
      let consts = Hashtbl.create 32 in
      let node = propagate consts node in
      Ast (prune_vardecs consts node)
    | _ -> raise (InvalidInput "constant propagation")