open Printf
open Types
open Util
open Stringify

(*
 * Do a number of checks:
 * - A void function must not return a value.
 * - A non-void function must return a value of the correct type.
 * - Array indices must be of type integer.
 * - The number of array indices must match the number of array dimensions.
 * - The type on the right-hand side of an assignment must match the type on
 *   the left-hand side.
 * - The number of arguments used for a function call must match the number of
 *   parameters for that function.
 * - The types of the function arguments must match the types of parameters.
 * - The operands of a unary or binary operation must have valid types.
 * - The predicate expression of an if, while, or do-while statement must be
 *   a boolean.
 * - Only values having a basic type can be type cast.
 *)

(* Normalise a type before comparison: an [Array] keeps its [ctype] but its
 * list of dimensions is replaced by the length of that list, so the dimension
 * nodes themselves never take part in the comparison. Any other type is
 * returned unchanged. *)
let spec = function
  | Array (ctype, dims) -> ArrayDepth (ctype, list_size dims)
  | ctype -> ctype

(* Check that the type of [node] equals [expected] (after normalisation by
 * [spec]). Raises [NodeError] on [node] otherwise, using [msg] as the error
 * message when it is non-empty and a generated "type mismatch" message when
 * it is empty (the default). *)
let check_type ?(msg = "") expected node =
  let got = typeof node in
  if spec got <> spec expected then begin
    let msg =
      match msg with
      | "" ->
        sprintf "type mismatch: expected type %s, got %s"
          (type2str expected) (type2str got)
      | _ -> msg
    in
    raise (NodeError (node, msg))
  end

(* Operand types that each operator accepts *)
let op_types = function
  | Not | And | Or -> [Bool]
  | Mod -> [Int]
  | Neg | Sub | Div | Lt | Le | Gt | Ge -> [Int; Float]
  | Add | Mul | Eq | Ne -> [Bool; Int; Float]

(* Result type of an operator applied to operands of type [opnd_type]:
 * logical and relational operators yield [Bool], arithmetic operators yield
 * the operand type. *)
let op_result_type opnd_type = function
  | Not | And | Or | Eq | Ne | Lt | Le | Gt | Ge -> Bool
  | Neg | Add | Sub | Mul | Div | Mod -> opnd_type

(* Check if the given operator can be applied to the given type: raises
 * [NodeError] on [node] when its type is not a member of [allowed_types].
 * [desc] names the operation in the error message. *)
let check_type_op allowed_types desc node =
  let got = typeof node in
  if not (List.mem got allowed_types) then begin
    let msg =
      sprintf "%s cannot be applied to type %s, only to %s"
        desc (type2str got) (types2str allowed_types)
    in
    raise (NodeError (errnode_of_self node, msg))
  end

(* Check that the number of indices in [dims] equals the array depth of
 * [dec_type]; raises [NodeError] on [errnode] otherwise. *)
let check_dims_match dims dec_type errnode =
  match (list_size dims, array_depth dec_type) with
  (* Structural comparison: [!=] would compare physical identity *)
  | (got, expected) when got <> expected ->
    let msg =
      sprintf "dimension mismatch: expected %d indices, got %d" expected got
    in
    raise (NodeError (errnode, msg))
  | _ -> ()

(* Recursively type-check [node]. Returns the node rebuilt from its
 * type-checked children; expression nodes additionally get a [Type]
 * annotation prepended to their annotation list. Raises [NodeError] on the
 * first violation that is found. Node kinds without a case of their own are
 * handed to [transform_children]. *)
let rec typecheck node =
  (* Type-check [node] and require the result to have type [ctype] *)
  let check_trav ctype node =
    let node = typecheck node in
    check_type ctype node;
    node
  in
  match node with
  (* Function calls: argument count and argument types must match the
   * parameters of the declaration; the call gets the return type *)
  | FunUse ((FunDec (ret_type, name, params, _) as dec), args, ann)
  | FunUse ((FunDef (_, ret_type, name, params, _, _) as dec), args, ann) ->
    (match (list_size args, list_size params) with
     | (nargs, nparams) when nargs <> nparams ->
       let msg =
         sprintf "function \"%s\" expects %d arguments, got %d"
           name nparams nargs
       in
       raise (NodeError (node, msg))
     | _ ->
       let args = List.map typecheck args in
       let check_arg_type arg param = check_type (typeof param) arg in
       List.iter2 check_arg_type args params;
       FunUse (dec, args, Type ret_type :: ann))

  (* Operators match operand types and get a new type based on the operator *)
  | Monop (op, opnd, ann) ->
    let opnd = typecheck opnd in
    let desc = sprintf "unary operator \"%s\"" (op2str op) in
    check_type_op (op_types op) desc opnd;
    Monop (op, opnd, Type (op_result_type (typeof opnd) op) :: ann)

  | Binop (op, left, right, ann) ->
    let left = typecheck left in
    let right = typecheck right in
    let desc = sprintf "binary operator \"%s\"" (op2str op) in
    (* Left operand must be valid for the operator, right must equal left *)
    check_type_op (op_types op) desc left;
    check_type (typeof left) right;
    Binop (op, left, right, Type (op_result_type (typeof left) op) :: ann)

  (* Conditions must be bool, and right-hand type must match left-hand type *)
  | Cond (cond, texpr, fexpr, ann) ->
    let cond = check_trav Bool cond in
    let texpr = typecheck texpr in
    let fexpr = check_trav (typeof texpr) fexpr in
    Cond (cond, texpr, fexpr, Type (typeof texpr) :: ann)

  (* Only basic types can be typecasted *)
  | TypeCast (ctype, value, ann) ->
    let value = typecheck value in
    check_type_op [Bool; Int; Float] "typecast" value;
    (* NOTE(review): the annotation records the type of the operand, not the
     * target type [ctype] - confirm this is what later phases expect *)
    TypeCast (ctype, value, Type (typeof value) :: ann)

  (* Array allocation dimensions must have type int *)
  | Allocate (dec, dims, ann) ->
    Allocate (dec, List.map (check_trav Int) dims, ann)

  (* Array dimensions are always integers *)
  | Dim (name, ann) ->
    Dim (name, Type Int :: ann)

  (* Functions and parameters must be traversed to give types to Dim nodes *)
  | FunDec (ret_type, name, params, ann) ->
    FunDec (ret_type, name, List.map typecheck params, ann)

  | Param (Array (ctype, dims), name, ann) ->
    Param (Array (ctype, List.map typecheck dims), name, ann)

  (* Void functions may have no return statement, other functions must have a
   * return statement of valid type *)
  | FunDef (export, ret_type, name, params, body, ann) ->
    let params = List.map typecheck params in
    let body = typecheck body in

    (* Only a Return that is the last statement of the list is found; a
     * Return at any earlier position is skipped *)
    let rec find_return = function
      | [] -> None
      | [Return (value, _) as ret] -> Some (ret, typeof value)
      | _ :: tl -> find_return tl
    in

    (match (ret_type, find_return (block_body body)) with
     | (Void, Some (ret, _)) ->
       raise (NodeError (ret, "void function should not have a return value"))
     | ((Bool | Int | Float), None) ->
       let msg =
         sprintf "expected return value of type %s for function \"%s\""
           (type2str ret_type) name
       in
       raise (NodeError (node, msg))
     (* Structural comparison: [!=] would compare physical identity *)
     | ((Bool | Int | Float), Some (ret, t)) when t <> ret_type ->
       let msg =
         sprintf "function \"%s\" has return type %s, got %s"
           name (type2str ret_type) (type2str t)
       in
       raise (NodeError (ret, msg))
     | _ ->
       FunDef (export, ret_type, name, params, body, ann))

  (* Conditions must have type bool *)
  | If (cond, body, ann) ->
    If (check_trav Bool cond, typecheck body, ann)
  | IfElse (cond, tbody, fbody, ann) ->
    IfElse (check_trav Bool cond, typecheck tbody, typecheck fbody, ann)
  | While (cond, body, ann) ->
    While (check_trav Bool cond, typecheck body, ann)
  | DoWhile (cond, body, ann) ->
    DoWhile (check_trav Bool cond, typecheck body, ann)

  (* Constants *)
  | Const (BoolVal value, ann) ->
    Const (BoolVal value, Type Bool :: ann)
  | Const (IntVal value, ann) ->
    Const (IntVal value, Type Int :: ann)
  | Const (FloatVal value, ann) ->
    Const (FloatVal value, Type Float :: ann)

  (* Variables inherit the type of their declaration *)
  | VarUse (dec, None, ann) ->
    VarUse (dec, None, Type (typeof dec) :: ann)

  | VarUse (dec, Some dims, ann) ->
    (* Array indices must be ints and match the declared number of
     * dimensions; the indexed use gets the base type of the declaration *)
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;
    check_dims_match dims (typeof dec) node;
    VarUse (dec, Some dims, Type (basetypeof dec) :: ann)

  (* Assigned values must match variable declaration *)
  | VarLet (dec, None, value, ann) ->
    VarLet (dec, None, check_trav (typeof dec) value, ann)

  | VarLet (dec, Some dims, value, ann) ->
    (* Number of assigned indices must match array definition *)
    check_dims_match dims (typeof dec) node;

    (* Array indices must be ints *)
    let dims = List.map typecheck dims in
    List.iter (check_type Int) dims;

    (* Assigned value must match array base type *)
    let value = typecheck value in
    check_type (basetypeof dec) value;
    VarLet (dec, Some dims, value, ann)

  | _ -> transform_children typecheck node

(* Entry point of the phase: logs a progress line, then type-checks the tree
 * of an [Ast] input. Raises [InvalidInput] for any other kind of input. *)
let phase input =
  log_line 1 "- Type checking";
  match input with
  | Ast node -> Ast (typecheck node)
  | _ -> raise (InvalidInput "typecheck")