open Printf
open Ast
open Util
open Stringify

(*
 * Do a number of checks:
 * - A void function must not return a value.
 * - A non-void function must return a value of the correct type.
 * - Array indices must be of type integer.
 * - The number of array indices must match the number of array dimensions.
 * - The type on the right-hand side of an assignment must match the type on
 *   the left-hand side.
 * - The number of arguments used for a function call must match the number of
 *   parameters for that function.
 * - The types of the function arguments must match the types of parameters.
 * - The operands of a unary or binary operation must have valid types.
 * - The predicate expression of an if, while, or do-while statement, or of a
 *   conditional (ternary) expression, must be a boolean.
 * - Only values having a basic type can be type cast.
 *)

(* Normalize a type for comparison: an array type is reduced to its base type
 * and its number of dimensions, so the dimension expressions themselves are
 * not compared. *)
let spec = function
  | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
  | ctype -> ctype

(* Check that an already type-annotated node [Type (node, got)] has type
 * [expected] (compared after normalization by [spec]).
 * [msg] replaces the default "type mismatch" message when non-empty.
 * Raises NodeError on a mismatch, InvalidNode if the argument is not a Type
 * node (i.e. it has not been type checked yet). *)
let check_type ?(msg="") expected = function
  | Type (node, got) when spec got <> spec expected ->
    let msg =
      match msg with
      | "" ->
        sprintf "type mismatch: expected type %s, got %s"
          (type2str expected) (type2str got)
      | _ -> msg
    in
    raise (NodeError (node, msg))
  | Type _ -> ()
  | _ -> raise InvalidNode

(* Operand types accepted by each operator *)
let op_types = function
  | Not | And | Or -> [Bool]
  | Mod -> [Int]
  | Neg | Sub | Div | Lt | Le | Gt | Ge -> [Int; Float]
  | Add | Mul | Eq | Ne -> [Bool; Int; Float]

(* Result type of an operator: comparisons and logical operators yield a
 * boolean, arithmetic operators yield the operand type *)
let op_result_type operand_type = function
  | Not | And | Or | Eq | Ne | Lt | Le | Gt | Ge -> Bool
  | Neg | Add | Sub | Mul | Div | Mod -> operand_type

(* Check if the given operator can be applied to the given type.
 * [desc] names the operator in the error message. Raises NodeError if the
 * type is not in [allowed_types], InvalidNode if the node is not typed. *)
let check_type_op allowed_types desc = function
  | Type (node, ctype) when not (List.mem ctype allowed_types) ->
    let msg = sprintf "%s cannot be applied to type %s, only to %s"
                desc (type2str ctype) (types2str allowed_types) in
    raise (NodeError (node, msg))
  | Type _ -> ()
  | _ -> raise InvalidNode

(* Raise a NodeError on [errnode] when the number of indices in [dims] differs
 * from the array depth of the declared type [dec_type] *)
let check_dims_match dims dec_type errnode =
  match (list_size dims, array_depth dec_type) with
  | (got, expected) when got <> expected ->
    let msg = sprintf "dimension mismatch: expected %d indices, got %d"
                expected got in
    raise (NodeError (errnode, msg))
  | _ -> ()

(* Type check [node] recursively.
 * Expressions are returned wrapped as [Type (node, ctype)]; statements and
 * definitions are returned as (possibly rebuilt) nodes. The recurring pattern
 * "X (untyped children) -> typecheck (X (typecheck children))" first types the
 * children and then re-enters the case that expects typed children.
 * Raises NodeError on a type error. *)
let rec typecheck node =
  match node with
  (* Function call: argument count and argument types must match the
   * parameters of the called function definition *)
  | FunUse (FunCall (_, args, _), FunDef (_, ftype, name, params, _, _), _) ->
    (match (list_size args, list_size params) with
     | (nargs, nparams) when nargs <> nparams ->
       let msg = sprintf "function \"%s\" expects %d arguments, got %d"
                   name nparams nargs in
       raise (NodeError (node, msg))
     | _ ->
       let check_arg_type arg param =
         check_type (ctypeof param) (typecheck arg)
       in
       List.iter2 check_arg_type args params;
       Type (node, ftype))

  | Arg (Type (_, vtype)) -> Type (node, vtype)
  | Arg value -> typecheck (Arg (typecheck value))

  | Monop (op, (Type (_, vtype) as value), _) ->
    let desc = sprintf "unary operator \"%s\"" (op2str op) in
    check_type_op (op_types op) desc value;
    Type (node, op_result_type vtype op)
  | Monop (op, value, loc) -> typecheck (Monop (op, typecheck value, loc))

  | Binop (op, (Type (_, ltype) as left), right, _) ->
    let desc = sprintf "binary operator \"%s\"" (op2str op) in
    check_type_op (op_types op) desc left;
    (* Both operands must have the same type *)
    check_type ltype right;
    Type (node, op_result_type ltype op)
  | Binop (op, left, right, loc) ->
    typecheck (Binop (op, typecheck left, typecheck right, loc))

  (* Conditional expression: the predicate must be a boolean and both
   * branches must have the same type, which is the type of the result *)
  | Cond ((Type _ as cond), Type (_, ttype), (Type _ as fexpr), _) ->
    check_type Bool cond;
    check_type ttype fexpr;
    Type (node, ttype)
  | Cond (cond, texpr, fexpr, loc) ->
    (* Without this case an untyped Cond would fall through to
     * transform_children and come back unwrapped, making an enclosing
     * operator's check_type raise InvalidNode *)
    typecheck (Cond (typecheck cond, typecheck texpr, typecheck fexpr, loc))

  | VarLet (Assign (_, None, (Type _ as value), _), dec_type, _) ->
    check_type dec_type value;
    node
  | VarLet (Assign (_, Some dims, (Type _ as value), _) as assign, dec_type, _) ->
    (* Number of assigned indices must match array definition *)
    check_dims_match dims dec_type assign;
    (* Array indices must be ints *)
    List.iter (check_type Int) dims;
    (* Assigned value must match array base type *)
    check_type (base_type dec_type) value;
    node
  | VarLet (assign, dec_type, depth) ->
    typecheck (VarLet (typecheck assign, dec_type, depth))

  (* Only basic types can be cast; the result has the target type *)
  | TypeCast (ctype, (Type _ as value), _) ->
    check_type_op [Bool; Int; Float] "typecast" value;
    Type (node, ctype)
  | TypeCast (ctype, value, loc) ->
    typecheck (TypeCast (ctype, typecheck value, loc))

  (* Array element access: indices must be ints and their number must match
   * the declared depth; the result has the array base type *)
  | VarUse (Deref (_, dims, _) as deref, dec_type, depth) ->
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;
    check_dims_match dims dec_type deref;
    typecheck (VarUse (Type (deref, base_type dec_type), dec_type, depth))
  | VarUse (Type (_, ctype), _, _)
  | VarUse (_, ctype, _) ->
    Type (node, ctype)

  | Allocate (name, dims, dec, loc) ->
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;
    Allocate (name, dims, dec, loc)

  | Return (Type _, _) -> node
  | Return (value, loc) -> typecheck (Return (typecheck value, loc))

  | FunDef (export, ret_type, name, params, body, loc) ->
    let params = transform_all typecheck params in
    let body = typecheck body in
    (* Only the last statement of the body is inspected: a typed Return there
     * yields its type. Returns nested in other statements are not found. *)
    let rec find_return = function
      | [] -> None
      | [Return (Type (_, rtype), _) as ret] -> Some (ret, rtype)
      | _ :: tl -> find_return tl
    in
    (match (ret_type, find_return (block_body body)) with
     | (Void, Some (ret, _)) ->
       raise (NodeError (ret, "void function should not have a return value"))
     | ((Bool | Int | Float), None) ->
       let msg = sprintf "expected return value of type %s for function \"%s\""
                   (type2str ret_type) name in
       raise (NodeError (node, msg))
     | ((Bool | Int | Float), Some (ret, t)) when t <> ret_type ->
       let msg = sprintf "function \"%s\" has return type %s, got %s"
                   name (type2str ret_type) (type2str t) in
       raise (NodeError (ret, msg))
     | _ -> FunDef (export, ret_type, name, params, body, loc))

  (* Conditions in if-statements and loop must be type bool *)
  | If (Type _ as cond, _, _)
  | IfElse (Type _ as cond, _, _, _)
  | While (Type _ as cond, _, _)
  | DoWhile (Type _ as cond, _, _) ->
    check_type Bool cond (*~msg:"condition should have type bool"*);
    node
  | If (cond, body, loc) ->
    typecheck (If (typecheck cond, typecheck body, loc))
  | IfElse (cond, tbody, fbody, loc) ->
    typecheck (IfElse (typecheck cond, typecheck tbody, typecheck fbody, loc))
  | While (cond, body, loc) ->
    typecheck (While (typecheck cond, typecheck body, loc))
  | DoWhile (cond, body, loc) ->
    typecheck (DoWhile (typecheck cond, typecheck body, loc))

  | BoolConst _ -> Type (node, Bool)
  | IntConst _ -> Type (node, Int)
  | FloatConst _ -> Type (node, Float)

  | _ -> transform_children typecheck node

(* Phase entry point: type check the whole AST.
 * Raises InvalidInput if the input is not an Ast. *)
let phase input =
  prerr_endline "- Type checking";
  match input with
  | Ast (node, args) -> Ast (typecheck node, args)
  | _ -> raise (InvalidInput "typecheck")