typecheck.ml 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. open Printf
  2. open Ast
  3. open Util
  4. open Stringify
  5. (*
  6. * This pass performs the following checks:
  7. * - A void function must not return a value.
  8. * - A non-void function must return a value of the correct type.
  9. * - Array indices must be of type integer.
  10. * - The number of array indices must match the number of array dimensions.
  11. * - The type on the right-hand side of an assignment must match the type on
  12. * the left-hand side.
  13. * - The number of arguments used for a function call must match the number of
  14. * parameters for that function.
  15. * - The types of the function arguments must match the types of parameters.
  16. * - The operands of a unary or binary operation must have types the operator accepts (see op_types), and both operands of a binary operation must have the same type.
  17. * - The predicate expression of an if, while, or do-while statement must be
  18. * a boolean.
  19. * - Only values having a basic type can be type cast.
  20. *)
  21. let spec = function
  22. | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
  23. | ctype -> ctype
  24. let check_type ?(msg="") expected = function
  25. | Type (node, got) when (spec got) <> (spec expected) ->
  26. let msg = match msg with
  27. | "" -> sprintf "type mismatch: expected type %s, got %s"
  28. (type2str expected) (type2str got)
  29. (*(type2str (spec expected)) (type2str (spec got))*)
  30. | _ -> msg
  31. in
  32. raise (NodeError (node, msg))
  33. | Type _ -> ()
  34. | _ -> raise InvalidNode
  35. let op_types = function
  36. | Not | And | Or -> [Bool]
  37. | Mod -> [Int]
  38. | Neg | Sub | Div | Lt | Le | Gt | Ge -> [Int; Float]
  39. | Add | Mul | Eq | Ne -> [Bool; Int; Float]
  40. let op_result_type operand_type = function
  41. | Not | And | Or | Eq | Ne | Lt | Le | Gt | Ge -> Bool
  42. | Neg | Add | Sub | Mul | Div | Mod -> operand_type
  43. (* Check if the given operator can be applied to the given type *)
  44. let check_type_op allowed_types desc = function
  45. | Type (node, ctype) when not (List.mem ctype allowed_types) ->
  46. let msg = sprintf
  47. "%s cannot be applied to type %s, only to %s"
  48. desc (type2str ctype) (types2str allowed_types)
  49. in
  50. raise (NodeError (node, msg))
  51. | Type _ -> ()
  52. | _ -> raise InvalidNode
  53. let check_dims_match dims dec_type errnode =
  54. match (list_size dims, array_depth dec_type) with
  55. | (got, expected) when got != expected ->
  56. let msg = sprintf
  57. "dimension mismatch: expected %d indices, got %d" expected got
  58. in
  59. raise (NodeError (errnode, msg))
  60. | _ -> ()
  61. let rec typecheck node = match node with
  62. | FunUse (FunCall (fname, args, floc),
  63. (FunDec (ftype, name, params, _) as dec), loc)
  64. | FunUse (FunCall (fname, args, floc),
  65. (FunDef (_, ftype, name, params, _, _) as dec), loc) ->
  66. (match (list_size args, list_size params) with
  67. | (nargs, nparams) when nargs != nparams ->
  68. let msg = sprintf
  69. "function \"%s\" expects %d arguments, got %d"
  70. name nparams nargs
  71. in
  72. raise (NodeError (node, msg))
  73. | _ ->
  74. let args = List.map typecheck args in
  75. let check_arg_type arg param =
  76. check_type (ctypeof param) arg;
  77. in
  78. List.iter2 check_arg_type args params;
  79. Type (FunUse (FunCall (fname, args, floc), dec, loc), ftype)
  80. )
  81. | Arg (Type (_, vtype)) -> Type (node, vtype)
  82. | Arg value -> typecheck (Arg (typecheck value))
  83. | Monop (op, (Type (_, vtype) as value), _) ->
  84. let desc = sprintf "unary operator \"%s\"" (op2str op) in
  85. check_type_op (op_types op) desc value;
  86. Type (node, op_result_type vtype op)
  87. | Monop (op, value, loc) ->
  88. typecheck (Monop (op, typecheck value, loc))
  89. | Binop (op, (Type (_, ltype) as left), right, loc) ->
  90. let desc = sprintf "binary operator \"%s\"" (op2str op) in
  91. check_type_op (op_types op) desc left;
  92. check_type ltype right;
  93. Type (node, op_result_type ltype op)
  94. | Binop (op, left, right, loc) ->
  95. typecheck (Binop (op, typecheck left, typecheck right, loc))
  96. | Cond (Type (cond, condtype), Type (texpr, ttype), fexpr, loc) ->
  97. check_type ttype fexpr;
  98. Type (node, ttype)
  99. | Cond (cond, texpr, fexpr, loc) ->
  100. typecheck (Cond (typecheck cond, typecheck texpr, typecheck fexpr, loc))
  101. | VarLet (Assign (_, None, (Type _ as value), _), dec_type, depth) ->
  102. check_type dec_type value;
  103. node
  104. | VarLet (Assign (_, Some dims, (Type _ as value), _) as assign, dec_type, depth) ->
  105. (* Number of assigned indices must match array definition *)
  106. check_dims_match dims dec_type assign;
  107. (* Array indices must be ints *)
  108. List.iter (check_type Int) dims;
  109. (* Assigned value must match array base type *)
  110. check_type (base_type dec_type) value;
  111. node
  112. | VarLet (assign, dec_type, depth) ->
  113. typecheck (VarLet (typecheck assign, dec_type, depth))
  114. | TypeCast (ctype, (Type _ as value), loc) ->
  115. check_type_op [Bool; Int; Float] "typecast" value;
  116. Type (node, ctype)
  117. | TypeCast (ctype, value, loc) ->
  118. typecheck (TypeCast (ctype, typecheck value, loc))
  119. | VarUse (Deref (_, dims, _) as deref, dec_type, depth) ->
  120. let dims = List.map typecheck dims in
  121. List.iter (check_type Int) dims;
  122. check_dims_match dims dec_type deref;
  123. typecheck (VarUse (Type (deref, base_type dec_type), dec_type, depth))
  124. | VarUse (Type (_, ctype), _, _)
  125. | VarUse (_, ctype, _) ->
  126. Type (node, ctype)
  127. | Allocate (name, dims, dec, loc) ->
  128. let dims = List.map typecheck dims in
  129. List.iter (check_type Int) dims;
  130. Allocate (name, dims, dec, loc)
  131. | Return (Type _, _) -> node
  132. | Return (value, loc) -> typecheck (Return (typecheck value, loc))
  133. | FunDef (export, ret_type, name, params, body, loc) ->
  134. let params = transform_all typecheck params in
  135. let body = typecheck body in
  136. let rec find_return = function
  137. | [] -> None
  138. | [Return (Type (_, rtype), _) as ret] -> Some (ret, rtype)
  139. | hd :: tl -> find_return tl
  140. in (
  141. match (ret_type, find_return (block_body body)) with
  142. | (Void, Some (ret, _)) ->
  143. raise (NodeError (ret, "void function should not have a return value"))
  144. | ((Bool | Int | Float), None) ->
  145. let msg = sprintf
  146. "expected return value of type %s for function \"%s\""
  147. (type2str ret_type) name
  148. in
  149. raise (NodeError (node, msg))
  150. | ((Bool | Int | Float), Some (ret, t)) when t != ret_type ->
  151. let msg = sprintf
  152. "function \"%s\" has return type %s, got %s"
  153. name (type2str ret_type) (type2str t)
  154. in
  155. raise (NodeError (ret, msg))
  156. | _ ->
  157. FunDef (export, ret_type, name, params, body, loc)
  158. )
  159. (* Conditions in if-statements and loop must be type bool *)
  160. | If (Type _ as cond, _, _)
  161. | IfElse (Type _ as cond, _, _, _)
  162. | While (Type _ as cond, _, _)
  163. | DoWhile (Type _ as cond, _, _) ->
  164. check_type Bool cond (*~msg:"condition should have type bool"*);
  165. node
  166. | If (cond, body, loc) ->
  167. typecheck (If (typecheck cond, typecheck body, loc))
  168. | IfElse (cond, tbody, fbody, loc) ->
  169. typecheck (IfElse (typecheck cond, typecheck tbody, typecheck fbody, loc))
  170. | While (cond, body, loc) ->
  171. typecheck (While (typecheck cond, typecheck body, loc))
  172. | DoWhile (cond, body, loc) ->
  173. typecheck (DoWhile (typecheck cond, typecheck body, loc))
  174. | BoolConst (value, _) -> Type (node, Bool)
  175. | IntConst (value, _) -> Type (node, Int)
  176. | FloatConst (value, _) -> Type (node, Float)
  177. | _ -> transform_children typecheck node
  178. let rec phase input =
  179. prerr_endline "- Type checking";
  180. match input with
  181. | Ast node -> Ast (typecheck node)
  182. | _ -> raise (InvalidInput "typecheck")