(* dim_reduce.ml *)
  1. open Ast
  2. open Util
  3. let rec expand_dims = function
  4. (* Flatten Block nodes returned by transformations below*)
  5. | FunDef (export, ret_type, name, params, body, loc) as node ->
  6. let params = flatten_blocks (List.map expand_dims params) in
  7. FunDef (export, ret_type, name, params, expand_dims body, loc)
  8. | FunDec (ret_type, name, params, loc) ->
  9. let params = flatten_blocks (List.map expand_dims params) in
  10. FunDec (ret_type, name, params, loc)
  11. | FunCall (name, args, loc) as node ->
  12. FunCall (name, flatten_blocks (List.map expand_dims args), loc)
  13. (* Add additional parameters for array dimensions *)
  14. | Param (ArrayDec (_, dims), name, _) as node ->
  15. let rec do_expand = function
  16. | [] -> [node]
  17. | Dim (name, loc) :: tail ->
  18. Param (Int, name, loc) :: (do_expand tail)
  19. | _ -> raise InvalidNode
  20. in
  21. Block (do_expand dims)
  22. (* Add additional function arguments for array dimensions *)
  23. | Arg (VarUse (_, ArrayDec (_, dims), _)) as node ->
  24. let rec do_expand = function
  25. | [] -> [node]
  26. | Dim (name, _) :: tail ->
  27. Var (name, noloc) :: (do_expand tail)
  28. | _ -> raise InvalidNode
  29. in
  30. Block (do_expand dims)
  31. | node -> transform_children expand_dims node
(* NOTE(review): disabled work-in-progress, kept for reference. As written
   it appears not to compile: [dims] in the ArrayScalar case is not bound
   by the pattern, [Var counter] passes a single argument whereas
   expand_dims builds [Var (name, noloc)], and the fallback calls
   [transform] rather than the [transform_children] used by expand_dims.
   The ArrayConst case only returns an empty Block, which looks like a
   placeholder. Confirm intent before re-enabling. *)
(*
let rec array_init = function
  (* transform scalar assignment into nested for loops *)
  | Assign (name, ArrayScalar (value)) ->
    let rec add_loop indices = function
      | [] ->
        Assign (Deref (name, indices), value)
      | dim :: rest ->
        let counter = fresh_var "counter" in
        let ind = (indices @ [Var counter]) in
        For (counter, IntConst 0, dim, IntConst 1, add_loop ind rest)
    in
    add_loop [] dims
  | Assign (name, ArrayConst (dims)) -> Block []
  | node -> transform array_init node
*)
  48. let rec phase input =
  49. prerr_endline "- Array dimension reduction";
  50. match input with
  51. | Ast (node, args) ->
  52. Ast (expand_dims node, args)
  53. | _ -> raise (InvalidInput "dimension reduction")