dim_reduce.ml 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. open Types
  2. open Util
  (* Fold a non-empty list of nodes into a single product expression.

     The result is right-nested: [a; b; c] becomes
     Binop (Mul, a, Binop (Mul, b, c, ...), ...), with every Binop
     annotated as [Type Int]. A singleton list yields its only element
     unchanged.

     Raises InvalidNode when the list is empty. *)
  let rec multiply nodes =
    match nodes with
    | [] -> raise InvalidNode
    | [only] -> only
    | first :: rest ->
      let product_of_rest = multiply rest in
      Binop (Mul, first, product_of_rest, [Type Int])
  (* Turn a dimension node into a use of that dimension.

     A Dim node is wrapped in an index-less VarUse annotated with
     [Type Int; Depth depth]; any other node is returned unchanged. *)
  let use_dim depth node =
    match node with
    | Dim _ -> VarUse (node, None, [Type Int; Depth depth])
    (* Unfinished case kept from the original source:
       | VarUse (dim, None, ann) -> VarUse () *)
    | _ -> node
  (* Collapse a list of array indices into one flat, row-major index.

     [dims] is the array's dimension list in REVERSED declaration order
     (that is how dim_reduce passes it); [values] are the index
     expressions in source order. For dimensions [n; m; k] and indices
     [i; j; l] the result is the Horner form ((i * m) + j) * k + l.
     For one and two dimensions this is the same tree the previous
     implementation built (i, respectively i * m + j); for three or more
     dimensions the previous pairing of indices with reversed dimensions
     produced i * k + j * m + l, which maps distinct elements to the same
     offset.

     Every index expression is itself passed through dim_reduce, and every
     dimension through use_dim, exactly once.

     Raises InvalidNode when [values] is empty or when there are more
     indices than dimensions. *)
  let rec expand depth dims values =
    (* Restore declaration order and drop the outermost dimension: its
       extent never contributes to an offset. *)
    let strides =
      match List.rev dims with
      | [] -> []
      | _ :: inner -> inner
    in
    (* acc holds the flat index of the indices consumed so far. *)
    let rec flatten acc strides indices =
      match strides, indices with
      | _, [] -> acc
      | [], _ :: _ -> raise InvalidNode
      | dim :: strides_left, index :: indices_left ->
        let scaled = Binop (Mul, acc, use_dim depth dim, [Type Int]) in
        let next =
          Binop (Add, scaled, dim_reduce depth index, [Type Int])
        in
        flatten next strides_left indices_left
    in
    match values with
    | [] -> raise InvalidNode
    | first :: rest -> flatten (dim_reduce depth first) strides rest

  (* Rewrite multi-dimensional array operations in [node] to their
     one-dimensional equivalents. [depth] is the function nesting depth
     that is recorded on generated dimension uses (see use_dim). *)
  and dim_reduce depth = function
    (* An allocation gets a single dimension: the product of all of them. *)
    | Allocate (dec, dims, ann) ->
      Allocate (dec, [multiply dims], ann)

    (* Increase nesting depth when going into a function *)
    | FunDef (export, ret_type, name, params, body, ann) ->
      let trav = dim_reduce (depth + 1) in
      FunDef (export, ret_type, name, List.map trav params, trav body, ann)

    (* Expand indices when dereferencing *)
    | VarUse (dec, Some values, ann) as node ->
      (match typeof dec with
       | Array (_, dims) ->
         VarUse (dec, Some [expand depth (List.rev dims) values], ann)
       | _ -> node)

    (* Expand indices when assigning to array index. The assigned value is
       traversed as well, so array dereferences on the right-hand side are
       flattened too. *)
    | VarLet (dec, Some values, value, ann) as node ->
      (match typeof dec with
       | Array (_, dims) ->
         VarLet (dec,
                 Some [expand depth (List.rev dims) values],
                 dim_reduce depth value,
                 ann)
       (* NOTE(review): indexed assignments to non-Array declarations are
          returned untouched, including their value - confirm that no
          multi-dimensional dereference can occur there. *)
       | _ -> node)

    (* Everything else: hand the children to transform_children with this
       same traversal. *)
    | node -> transform_children (dim_reduce depth) node
  (* Replace multi-dimensional array types in declarations by flat ones.

     A VarDec or Param whose type is Array (ctype, _) is rebuilt with the
     type FlatArray ctype; its other fields are kept as they are. Any other
     node is handed to transform_children together with this function
     (presumably so that nested declarations are reached - see Util). *)
  let rec simplify_decs node =
    match node with
    | VarDec (Array (elem_type, _), name, init, ann) ->
      VarDec (FlatArray elem_type, name, init, ann)
    | Param (Array (elem_type, _), name, ann) ->
      Param (FlatArray elem_type, name, ann)
    | other -> transform_children simplify_decs other
  (* Entry point of the dimension reduction phase.

     For an Ast input, first runs dim_reduce at nesting depth 0 and then
     simplify_decs on the result, and wraps the outcome in Ast again.

     Raises InvalidInput "dimension reduction" for any other input. *)
  let phase input =
    match input with
    | Ast root ->
      let reduced = dim_reduce 0 root in
      let flattened = simplify_decs reduced in
      Ast flattened
    | _ -> raise (InvalidInput "dimension reduction")