(*
 * Do a number of checks:
 * - A void function must not return a value.
 * - A non-void function must return a value of the correct type.
 * - Array indices must be of type integer.
 * - The number of array indices must match the number of array dimensions.
 * - The type on the right-hand side of an assignment must match the type on
 *   the left-hand side.
 * - The number of arguments used for a function call must match the number of
 *   parameters for that function.
 * - The types of the function arguments must match the types of parameters.
 * - The operands of a unary or binary operation must have valid types.
 * - The predicate expression of an if, while, or do-while statement must be
 *   a boolean.
 * - Only values having a basic type can be type cast.
 *)

open Printf
open Types
open Util
open Stringify

(* Number of dimensions of an array type.
 * Raises InvalidNode when the type is not an ArrayDims type. *)
let array_depth = function
  | ArrayDims (_, dims) -> List.length dims
  | _ -> raise InvalidNode

(* Split a type into (base type, number of array dimensions). Scalar types
 * have 0 dimensions. Only the number of dimension expressions is kept, not
 * the expressions themselves. *)
let spec = function
  | ArrayDims (ctype, dims) -> (ctype, List.length dims)
  | ctype -> (ctype, 0)

(* Raise a NodeError on [node] unless its type has the same base type and
 * number of dimensions as [expected]. A non-empty [msg] replaces the default
 * "type mismatch" message. *)
let check_type ?(msg="") expected node =
  let got = typeof node in
  if spec got <> spec expected then begin
    let msg =
      match msg with
      | "" ->
        sprintf "type mismatch: expected type %s, got %s"
          (type2str expected) (type2str got)
      | _ -> msg
    in
    raise (NodeError (node, msg))
  end

(* Operand types accepted by each operator *)
let op_types = function
  | Not | And | Or -> [Bool]
  | Mod -> [Int]
  | Neg | Sub | Div | Lt | Le | Gt | Ge -> [Int; Float]
  | Add | Mul | Eq | Ne -> [Bool; Int; Float]

(* Result type of an operator, given the type of its (left) operand:
 * logical and comparison operators yield Bool, arithmetic keeps the operand
 * type *)
let op_result_type opnd_type = function
  | Not | And | Or | Eq | Ne | Lt | Le | Gt | Ge -> Bool
  | Neg | Add | Sub | Mul | Div | Mod -> opnd_type

(* Check if the given operator can be applied to the given type; [desc]
 * names the operator in the error message *)
let check_type_op allowed_types desc node =
  let got = typeof node in
  if not (List.mem got allowed_types) then begin
    let msg = sprintf "%s cannot be applied to type %s, only to %s"
        desc (type2str got) (types2str allowed_types) in
    raise (NodeError (node, msg))
  end

(* Raise a NodeError on [errnode] when the number of indices in [dims] differs
 * from the number of dimensions of [dec_type]. Scalars count as having zero
 * dimensions, so indexing a non-array variable is reported as a user error
 * rather than escaping as an internal InvalidNode. *)
let check_dims_match dims dec_type errnode =
  let got = List.length dims in
  let expected = snd (spec dec_type) in
  if got <> expected then begin
    let msg = sprintf "dimension mismatch: expected %d indices, got %d"
        expected got in
    raise (NodeError (errnode, msg))
  end

(* Typecheck an AST node, returning a copy in which expression nodes carry a
 * [Type] annotation. Raises NodeError on the first type error found. *)
let rec typecheck node =
  (* Typecheck [node] and require the result to have type [ctype] *)
  let check_trav ctype node =
    let node = typecheck node in
    check_type ctype node;
    node
  in
  match node with
  (* Calls must pass as many arguments as there are parameters, each matching
   * its parameter's type; the call gets the function's return type *)
  | FunUse ((FunDec (ret_type, name, params, _) as dec), args, ann)
  | FunUse ((FunDef (_, ret_type, name, params, _, _) as dec), args, ann) ->
    let nargs = List.length args and nparams = List.length params in
    if nargs <> nparams then begin
      let msg = sprintf "function \"%s\" expects %d arguments, got %d"
          name nparams nargs in
      raise (NodeError (node, msg))
    end;
    let args = List.map typecheck args in
    List.iter2 (fun arg param -> check_type (typeof param) arg) args params;
    FunUse (dec, args, Type ret_type :: ann)

  (* Operators match operand types and get a new type based on the operator *)
  | Monop (op, opnd, ann) ->
    let opnd = typecheck opnd in
    let desc = sprintf "unary operator \"%s\"" (op2str op) in
    check_type_op (op_types op) desc opnd;
    Monop (op, opnd, Type (op_result_type (typeof opnd) op) :: ann)

  | Binop (op, left, right, ann) ->
    let left = typecheck left in
    let right = typecheck right in
    let desc = sprintf "binary operator \"%s\"" (op2str op) in
    check_type_op (op_types op) desc left;
    check_type (typeof left) right;
    (* Warn about a constant zero divisor; integer modulo divides as well *)
    begin match op, right with
      | (Div | Mod), Const (IntVal 0, _) ->
        node_warning right "division by zero"
      | _ -> ()
    end;
    Binop (op, left, right, Type (op_result_type (typeof left) op) :: ann)

  (* Conditions must be bool, and right-hand type must match left-hand type *)
  | Cond (cond, texpr, fexpr, ann) ->
    let cond = check_trav Bool cond in
    let texpr = typecheck texpr in
    let fexpr = check_trav (typeof texpr) fexpr in
    Cond (cond, texpr, fexpr, Type (typeof texpr) :: ann)

  (* Only basic types can be typecasted *)
  | TypeCast (ctype, value, ann) ->
    let value = typecheck value in
    check_type_op [Bool; Int; Float] "typecast" value;
    TypeCast (ctype, value, Type ctype :: ann)

  (* Array allocation dimensions must have type int *)
  | Allocate (dec, dims, ann) ->
    Allocate (dec, List.map (check_trav Int) dims, ann)

  (* Array dimensions are always integers *)
  | Dim (name, ann) ->
    Dim (name, Type Int :: ann)

  (* Functions and parameters must be traversed to give types to Dim nodes *)
  (*
  | FunDec (ret_type, name, params, ann) ->
    FunDec (ret_type, name, List.map typecheck params, ann)
  | Param (ArrayDims (ctype, dims), name, ann) ->
    Param (ArrayDims (ctype, List.map typecheck dims), name, ann)
  *)

  (* Void functions may have no return statement, other functions must have a
   * return statement of valid type *)
  | FunDef (export, ret_type, name, params, body, ann) ->
    let params = List.map typecheck params in
    let body = typecheck body in
    (* Only the last statement of the body is considered as a return *)
    let rec find_return = function
      | [] -> None
      | [Return (value, _) as ret] -> Some (ret, typeof value)
      | _ :: tl -> find_return tl
    in
    begin match ret_type, find_return (block_body body) with
      | Void, Some (ret, _) ->
        raise (NodeError (ret, "void function should not have a return value"))
      | (Bool | Int | Float), None ->
        let msg = sprintf "expected return value of type %s for function \"%s\""
            (type2str ret_type) name in
        raise (NodeError (node, msg))
      (* Structural inequality: types must be compared by value, not identity *)
      | (Bool | Int | Float), Some (ret, t) when t <> ret_type ->
        let msg = sprintf "function \"%s\" has return type %s, got %s"
            name (type2str ret_type) (type2str t) in
        raise (NodeError (ret, msg))
      | _ -> FunDef (export, ret_type, name, params, body, ann)
    end

  (* Conditions must have type bool *)
  | If (cond, body, ann) ->
    If (check_trav Bool cond, typecheck body, ann)
  | IfElse (cond, tbody, fbody, ann) ->
    IfElse (check_trav Bool cond, typecheck tbody, typecheck fbody, ann)
  | While (cond, body, ann) ->
    While (check_trav Bool cond, typecheck body, ann)
  | DoWhile (cond, body, ann) ->
    DoWhile (check_trav Bool cond, typecheck body, ann)

  (* Constants *)
  | Const (BoolVal value, ann) ->
    Const (BoolVal value, Type Bool :: ann)
  | Const (IntVal value, ann) ->
    (* Do a bound check on integers (use Int32 because default ints in ocaml
     * are 31- or 63-bit) *)
    let cmpval = Nativeint.of_int value in
    let min = Nativeint.of_int32 Int32.min_int in
    let max = Nativeint.of_int32 Int32.max_int in
    if cmpval < min || cmpval > max then
      raise (NodeError (node, "integer value out of range (signed 32-bit)"));
    Const (IntVal value, Type Int :: ann)
  | Const (FloatVal value, ann) ->
    Const (FloatVal value, Type Float :: ann)

  (* Variables inherit the type of their declaration *)
  | VarUse (dec, None, ann) ->
    VarUse (dec, None, Type (typeof dec) :: ann)
  | VarUse (dec, Some dims, ann) ->
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;
    check_dims_match dims (typeof dec) node;
    VarUse (dec, Some dims, Type (basetypeof dec) :: ann)

  (* Array pointers cannot be re-assigned, because array dimension reduction
   * makes assumptions about dimensions of an array *)
  | VarLet (dec, None, _, _) when is_array dec ->
    raise (NodeError (node, "cannot assign value to array pointer after \
                             initialisation"))

  (* Assigned values must match variable declaration *)
  | VarLet (dec, None, value, ann) ->
    VarLet (dec, None, check_trav (typeof dec) value, ann)

  | VarLet (dec, Some dims, value, ann) ->
    (* Number of assigned indices must match array definition *)
    check_dims_match dims (typeof dec) node;
    (* Array indices must be ints *)
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;
    (* Assigned value must match array base type *)
    let value = typecheck value in
    check_type (basetypeof dec) value;
    VarLet (dec, Some dims, value, ann)

  (* ArrayConst initialisations are transformed during desugaring, so any
   * occurrences that are left are illegal *)
  | ArrayConst _ ->
    raise (NodeError (node, "array constants can only be used in array \
                             initialisation"))

  | _ -> transform_children typecheck node

(* Compiler phase entry point: typecheck a whole AST *)
let phase = function
  | Ast node -> Ast (typecheck node)
  | _ -> raise (InvalidInput "typecheck")