|
|
@@ -0,0 +1,148 @@
|
|
|
+open Printf
|
|
|
+open Ast
|
|
|
+open Util
|
|
|
+open Stringify
|
|
|
+
|
|
|
+(*
|
|
|
+ * Do a number of checks:
|
|
|
+ * x A void function must not return a value.
|
|
|
+ * x A non-void function must return a value of the correct type.
|
|
|
+ * - Array indices must be of type integer.
|
|
|
+ * - The number of array indices must match the number of array dimensions.
|
|
|
+ * x The type on the right-hand side of an assignment must match the type on
|
|
|
+ * the left-hand side.
|
|
|
+ * - The number of arguments used for a function call must match the number of
|
|
|
+ * parameters for that function.
|
|
|
+ * - The types of the function arguments must match the types of parameters.
|
|
|
+ * x The operands of a unary or binary operation must have valid types.
|
|
|
+ * x The predicate expression of an if, while, or do-while statement must be
|
|
|
+ * a boolean.
|
|
|
+ * x Only values having a basic type can be type cast.
|
|
|
+ *)
|
|
|
+
|
|
|
+(** [check_type ?msg expected node] verifies that the typed node [node]
+    carries the type [expected].
+
+    [node] has to be a [Type (inner, got)] wrapper. Nothing happens when
+    [got] equals [expected]. Otherwise [NodeError (inner, message)] is raised,
+    where [message] is [msg] when a non-empty one was supplied and a generated
+    "expected type ..., got ..." text when it was not.
+
+    @raise NodeError when the two types differ.
+    @raise InvalidNode when [node] is not a [Type] wrapper. *)
+let check_type ?(msg="") expected = function
+  (* Structural inequality is required here: [!=] compares identity, so two
+     separately built but identical type values would count as a mismatch. *)
+  | Type (node, got) when got <> expected ->
+    let msg = match msg with
+      | "" -> sprintf "expected type %s, got %s"
+                (type2str expected) (type2str got)
+      | _ -> msg
+    in
+    raise (NodeError (node, msg))
+  | Type _ -> ()
+  | _ -> raise InvalidNode
|
|
|
+
|
|
|
+(** [op_types op] lists the operand types that operator [op] accepts. *)
+let op_types op =
+  let numeric = [Int; Float] in
+  match op with
+  | Mod -> [Int]
+  | Not | And | Or -> [Bool]
+  | Add | Mul | Eq | Ne -> Bool :: numeric
+  | Neg | Sub | Div | Lt | Le | Gt | Ge -> numeric
|
|
|
+
|
|
|
+(** [op_result_type operand_type op] is the type produced by applying [op]
+    to operands of type [operand_type]: logical and comparison operators
+    always give [Bool], arithmetic operators give back the operand type. *)
+let op_result_type operand_type op =
+  match op with
+  | Neg | Add | Sub | Mul | Div | Mod -> operand_type
+  | Not | And | Or | Eq | Ne | Lt | Le | Gt | Ge -> Bool
|
|
|
+
|
|
|
+(** [check_type_op allowed_types desc typed] checks that the type carried by
+    the [Type] wrapper [typed] is a member of [allowed_types]. [desc] names
+    the operation and is used as the start of the error message.
+
+    @raise NodeError when the type is not in [allowed_types].
+    @raise InvalidNode when [typed] is not a [Type] wrapper. *)
+let check_type_op allowed_types desc typed =
+  match typed with
+  | Type (node, ctype) ->
+    if not (List.mem ctype allowed_types) then
+      raise (NodeError (node,
+        sprintf "%s cannot be applied to type %s, only to %s"
+          desc (type2str ctype) (types2str allowed_types)))
+  | _ -> raise InvalidNode
|
|
|
+
|
|
|
+(** [typecheck node] type-checks [node] together with its descendants.
+
+    Constants, variable/function uses, operators, conditional expressions,
+    assignments and type casts come back wrapped as [Type (node, ctype)].
+    [Return], [FunDef] and the if/loop statements come back as plain nodes
+    once their checks have passed. Every other node is handed to
+    [transform_children typecheck].
+
+    Most constructs are handled in two steps: a case that fires when the
+    operands are already [Type] wrappers and performs the actual check, and a
+    fallback case that type-checks the operands, rebuilds the node and runs
+    [typecheck] on it again.
+
+    @raise NodeError (directly, or through [check_type] / [check_type_op])
+      when a type rule is violated.
+    @raise InvalidNode (through [check_type] / [check_type_op]) when an
+      operand that has to be a [Type] wrapper is not one. *)
+let rec typecheck node = match node with
+  (* Leaves: the type is fixed by the constructor or already stored in it. *)
+  | BoolConst _ -> Type (node, Bool)
+  | IntConst _ -> Type (node, Int)
+  | FloatConst _ -> Type (node, Float)
+  | VarUse (_, ctype, _) -> Type (node, ctype)
+  | FunUse (_, ret_type, _) -> Type (node, ret_type)
+
+  (* Unary operator: the operand type must be accepted by the operator. *)
+  | Monop (op, (Type (_, vtype) as value), _) ->
+    let desc = sprintf "unary operator \"%s\"" (op2str op) in
+    check_type_op (op_types op) desc value;
+    Type (node, op_result_type vtype op)
+  | Monop (op, value, loc) ->
+    typecheck (Monop (op, typecheck value, loc))
+
+  (* Binary operator: the left operand must be accepted by the operator and
+     the right operand must have the same type as the left one. *)
+  | Binop (op, (Type (_, ltype) as left), right, _) ->
+    let desc = sprintf "binary operator \"%s\"" (op2str op) in
+    check_type_op (op_types op) desc left;
+    check_type ltype right;
+    Type (node, op_result_type ltype op)
+  | Binop (op, left, right, loc) ->
+    typecheck (Binop (op, typecheck left, typecheck right, loc))
+
+  (* Conditional expression: both branches must have the same type.
+     NOTE(review): the type of the condition itself is not checked here, and
+     there is no fallback case that types the operands first - confirm that
+     both are intended. *)
+  | Cond (Type _, Type (_, ttype), fexpr, _) ->
+    check_type ttype fexpr;
+    Type (node, ttype)
+
+  (* Assignment: the value must have the declared type. *)
+  | VarLet (_, (Type (_, vtype) as value), dec_type, _) ->
+    check_type dec_type value;
+    Type (node, vtype)
+  | VarLet (assign, value, dec_type, depth) ->
+    typecheck (VarLet (assign, typecheck value, dec_type, depth))
+
+  (* Only values of a basic type can be cast. *)
+  | TypeCast (ctype, (Type _ as value), _) ->
+    check_type_op [Bool; Int; Float] "typecast" value;
+    Type (node, ctype)
+  | TypeCast (ctype, value, loc) ->
+    typecheck (TypeCast (ctype, typecheck value, loc))
+
+  (* A return statement only needs its value typed; the comparison with the
+     function's return type happens in the FunDef case. *)
+  | Return (Type _, _) ->
+    node
+  | Return (value, loc) ->
+    typecheck (Return (typecheck value, loc))
+
+  | FunDef (export, ret_type, name, params, body, loc) ->
+    let params = transform_all typecheck params in
+    let body = typecheck body in
+    (* Only a typed [Return] in the very last position of the statement list
+       is reported; returns anywhere else are not looked at. *)
+    let rec find_return = function
+      | [] -> None
+      | [Return (Type (_, rtype), _) as ret] -> Some (ret, rtype)
+      | _ :: tl -> find_return tl
+    in (
+      match (ret_type, find_return (block_body body)) with
+      | (Void, Some (ret, _)) ->
+        raise (NodeError (ret, "void function should not have a return value"))
+      | ((Bool | Int | Float), None) ->
+        let msg =
+          sprintf "expected return value of type %s for function \"%s\""
+            (type2str ret_type) name
+        in
+        raise (NodeError (node, msg))
+      (* Structural comparison: [!=] would compare identity, not value. *)
+      | ((Bool | Int | Float), Some (ret, t)) when t <> ret_type ->
+        let msg =
+          sprintf "function \"%s\" has return type %s, got %s"
+            name (type2str ret_type) (type2str t)
+        in
+        raise (NodeError (ret, msg))
+      | _ ->
+        FunDef (export, ret_type, name, params, body, loc)
+    )
+
+  (* Conditions of if-statements and loops must have type bool. *)
+  | If ((Type _ as cond), _, _)
+  | IfElse ((Type _ as cond), _, _, _)
+  | While ((Type _ as cond), _, _)
+  | DoWhile ((Type _ as cond), _, _) ->
+    check_type Bool cond;
+    node
+  | If (cond, body, loc) ->
+    typecheck (If (typecheck cond, typecheck body, loc))
+  | IfElse (cond, tbody, fbody, loc) ->
+    typecheck (IfElse (typecheck cond, typecheck tbody, typecheck fbody, loc))
+  | While (cond, body, loc) ->
+    typecheck (While (typecheck cond, typecheck body, loc))
+  | DoWhile (cond, body, loc) ->
+    typecheck (DoWhile (typecheck cond, typecheck body, loc))
+
+  (* Anything else: descend into the children. *)
+  | _ -> transform_children typecheck node
|
|
|
+
|
|
|
+(** [phase input] is the entry point of the type-checking phase. It writes
+    a progress line to stderr, then type-checks the tree held by an
+    [Ast (tree, args)] input and returns it with [args] unchanged.
+
+    @raise InvalidInput when [input] is not an [Ast] value. *)
+let phase input =
+  let () = prerr_endline "- Type checking" in
+  match input with
+  | Ast (tree, args) ->
+    let checked = typecheck tree in
+    Ast (checked, args)
+  | _ -> raise (InvalidInput "typecheck")
|