expand_dims.ml 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. open Ast
  2. open Util
  (* Rewrite array-typed parameters and arguments so that every array
   * dimension is passed explicitly as its own Int parameter/argument.
   *
   * - Param (ArrayDec (ctype, dims), ...) becomes a Block holding one
   *   Param (Int, ...) per Dim in [dims], followed by the array parameter
   *   itself, retyped as ArraySpec (ctype, list_size dims).
   * - Arg (VarUse (_, ArrayDec (ctype, dims), _)) becomes a Block holding one
   *   Int-typed variable use per Dim, followed by the array argument retyped
   *   the same way.
   * - Function definitions, declarations and calls have their parameter or
   *   argument lists expanded and passed through flatten_blocks, which
   *   flattens the Block nodes produced by the two cases above.
   * - Any other node is handled by transform_children.
   *
   * Raises InvalidNode when a dimension list contains something other than
   * a Dim node. *)
  let rec expand_dims node =
    (* Expand each node of a list, then flatten the resulting Block nodes *)
    let expand_list nodes = flatten_blocks (List.map expand_dims nodes) in
    match node with
    | FunDef (export, ret_type, name, params, body, loc) ->
      let expanded_params = expand_list params in
      FunDef (export, ret_type, name, expanded_params, expand_dims body, loc)

    | FunDec (ret_type, name, params, loc) ->
      FunDec (ret_type, name, expand_list params, loc)

    | FunUse (funcall, fundef, depth) ->
      FunUse (expand_dims funcall, expand_dims fundef, depth)

    | FunCall (name, args, loc) ->
      FunCall (name, expand_list args, loc)

    (* Array parameter: prepend one Int parameter per dimension *)
    | Param (ArrayDec (ctype, dims), name, loc) ->
      let dim_params =
        List.map
          (function
            | Dim (dim_name, dim_loc) -> Param (Int, dim_name, dim_loc)
            | _ -> raise InvalidNode)
          dims
      in
      (* Bound after dim_params so an invalid dimension is reported first *)
      let array_param = Param (ArraySpec (ctype, list_size dims), name, loc) in
      Block (dim_params @ [array_param])

    (* Array argument: prepend one Int argument per dimension *)
    | Arg (VarUse (var, ArrayDec (ctype, dims), depth)) ->
      let dim_args =
        List.map
          (function
            | Dim (dim_name, _) -> Arg (VarUse (Var (dim_name, noloc), Int, depth))
            | _ -> raise InvalidNode)
          dims
      in
      (* Bound after dim_args so an invalid dimension is reported first *)
      let array_arg = Arg (VarUse (var, ArraySpec (ctype, list_size dims), depth)) in
      Block (dim_args @ [array_arg])

    | other -> transform_children expand_dims other
  (* Entry point of the "expand array dimensions" phase.
   *
   * Writes a progress line to stderr, then applies expand_dims to the root
   * node of an Ast input, passing the accompanying args through unchanged.
   *
   * Raises InvalidInput "expand dimensions" for any input that is not an Ast.
   *
   * The function is not recursive, so it is bound with a plain `let`
   * (a `let rec` here triggers the compiler's unused-rec warning). *)
  let phase input =
    prerr_endline "- Expand array dimensions";
    match input with
    | Ast (node, args) -> Ast (expand_dims node, args)
    | _ -> raise (InvalidInput "expand dimensions")