dim_reduce.ml 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. open Ast
  2. open Util
  (* [multiply nodes] folds a non-empty list of expressions into a single
     right-nested product: [a; b; c] becomes [a * (b * c)], and a one-element
     list yields that element unchanged.
     Raises [InvalidNode] when [nodes] is empty. *)
  let multiply nodes =
    match List.rev nodes with
    | [] -> raise InvalidNode
    | last :: earlier ->
      (* Walking the reversed list and wrapping each factor on the left
         rebuilds exactly the right-nested tree, innermost product first. *)
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, noloc))
        last earlier
  (* [expand dims values] linearises a multi-dimensional index into a single
     row-major offset expression.

     [dims] holds the array's dimension-size expressions in REVERSED order
     (innermost dimension first); [dim_reduce] calls it with [List.rev dims].
     [values] holds the index expressions in source order (outermost first).
     Every index expression is itself passed through [dim_reduce], so nested
     multi-dimensional accesses inside an index are reduced as well.

     For indices [i0; i1; ...; ik] and declared dimensions [d0; d1; ...] the
     offset is built with Horner's rule:
       ((i0 * d1 + i1) * d2 + i2) ... * dk + ik
     The outermost dimension [d0] never contributes a factor. With a single
     index the (reduced) index is returned unchanged.

     Raises [InvalidNode] when [values] is empty, or when there are more
     indices than the dimensions can provide strides for. *)
  let rec expand dims values =
    match values with
    | [] -> raise InvalidNode
    | first :: rest ->
      (* Restore declaration order and drop d0, so that index k (k >= 1)
         pairs with dimension dk. *)
      let strides =
        match List.rev dims with
        | [] -> []
        | _outermost :: inner -> inner
      in
      let rec horner acc pending remaining_strides =
        match pending, remaining_strides with
        | [], _ -> acc
        | index :: pending', stride :: remaining_strides' ->
          (* Terms are summed: the original joined them with [Mul], which
             does not yield a flat offset. *)
          let scaled = Binop (Mul, acc, stride, noloc) in
          horner (Binop (Add, scaled, dim_reduce index, noloc))
            pending' remaining_strides'
        | _ :: _, [] -> raise InvalidNode
      in
      horner (dim_reduce first) rest strides

  (* [dim_reduce node] rewrites multi-dimensional array constructs in [node]
     into their one-dimensional equivalents:
     - [Allocate]: the dimension list collapses to one product of all
       dimensions (via [multiply]);
     - [VarUse] of a [Deref] typed as an [Array]: the index list collapses to
       a single flattened offset (via [expand]); an empty index list raises
       [InvalidNode];
     - [VarLet] assigning through indices ([Some values]) into an [Array]:
       the indices are flattened the same way and the assigned value is
       reduced too.
     Every other node is handed to [transform_children dim_reduce] (defined
     outside this file; presumably rebuilds the node with reduced children). *)
  and dim_reduce = function
    | Allocate (name, dims, dec, loc) ->
      (* Size expressions may themselves contain multi-dimensional accesses;
         intercepting this case must not skip reducing them. *)
      Allocate (name, [multiply (List.map dim_reduce dims)], dec, loc)
    | VarUse (Type (Deref (name, values, loc), t), (Array (_, dims) as ctype), depth) ->
      let reduced = [expand (List.rev dims) values] in
      VarUse (Type (Deref (name, reduced, loc), t), ctype, depth)
    | VarLet (Assign (name, Some values, value, loc), (Array (_, dims) as ctype), depth) ->
      let reduced = Some [expand (List.rev dims) values] in
      VarLet (Assign (name, reduced, dim_reduce value, loc), ctype, depth)
    | node -> transform_children dim_reduce node
  (* [phase input] is the pipeline entry point for array dimension reduction.
     It prints "- Array dimension reduction" on stderr, then applies
     [dim_reduce] to the tree carried by an [Ast] input.
     Raises [InvalidInput "dimension reduction"] for any other input kind.
     (No [rec]: the function never calls itself.) *)
  let phase input =
    prerr_endline "- Array dimension reduction";
    match input with
    | Ast node -> Ast (dim_reduce node)
    | _ -> raise (InvalidInput "dimension reduction")