(* dim_reduce.ml *)
  1. open Types
  2. open Util
  3. let rec multiply = function
  4. | [] -> raise InvalidNode
  5. | [node] -> node
  6. | hd :: tl -> Binop (Mul, hd, multiply tl, [Type Int])
  7. let use_dim depth = function
  8. | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
  9. (*| VarUse (dim, None, ann) -> VarUse ()*)
  10. | node -> node
  11. let rec expand depth dims = function
  12. | [] -> raise InvalidNode
  13. | [node] -> dim_reduce depth node
  14. | hd :: tl ->
  15. let dim = use_dim depth (List.hd dims) in
  16. let mul = Binop (Mul, dim_reduce depth hd, dim, [Type Int]) in
  17. Binop (Add, mul, expand depth (List.tl dims) tl, [Type Int])
  18. and dim_reduce depth = function
  19. | Allocate (name, dims, dec, ann) ->
  20. Allocate (name, [multiply dims], dec, ann)
  21. (* Increase nesting depth when goiing into function *)
  22. | FunDef (export, ret_type, name, params, body, ann) ->
  23. let trav = dim_reduce (depth + 1) in
  24. FunDef (export, ret_type, name, List.map trav params, trav body, ann)
  25. (* Expand indices when dereferencing *)
  26. | VarUse (VarDec (Array (_, dims), _, _, _) as dec, Some values, ann) ->
  27. VarUse (dec, Some [expand depth (List.rev dims) values], ann)
  28. (* Expand indices when assigning to array index *)
  29. | VarLet (VarDec (Array (_, dims), _, _, _) as dec, Some values, value, ann) ->
  30. VarLet (dec, Some [expand depth (List.rev dims) values], value, ann)
  31. | node -> transform_children (dim_reduce depth) node
  32. let rec simplify_decs = function
  33. | VarDec (Array (ctype, dims), name, init, ann) ->
  34. VarDec (FlatArray ctype, name, init, ann)
  35. | Param (Array (ctype, dims), name, ann) ->
  36. Param (FlatArray ctype, name, ann)
  37. | node -> transform_children simplify_decs node
  38. let rec phase input =
  39. log_line 2 "- Array dimension reduction";
  40. match input with
  41. | Types node -> Types (simplify_decs (dim_reduce 0 node))
  42. | _ -> raise (InvalidInput "dimension reduction")