(* dim_reduce.ml *)
  1. open Ast
  2. open Util
  (** [multiply nodes] combines a non-empty list of nodes into one right-nested
      product: [[a; b; c]] becomes
      [Binop (Mul, a, Binop (Mul, b, c, noloc), noloc)].
      A single-element list yields that element unchanged.
      @raise InvalidNode when [nodes] is empty. *)
  let multiply nodes =
    (* Start from the last factor and wrap the earlier ones around it, so the
       nesting still leans to the right. *)
    match List.rev nodes with
    | [] -> raise InvalidNode
    | last :: earlier ->
      List.fold_left
        (fun product factor -> Binop (Mul, factor, product, noloc))
        last
        earlier
  (** [expand dims depth values] collapses the list [values] into a single
      expression, consuming one entry of [dims] per non-final value.

      - The final value is only passed through [dim_reduce depth].
      - Every earlier value is reduced, multiplied by the reduced head of
        [dims], and that product is multiplied by the expansion of the
        remaining values over the tail of [dims].

      @raise InvalidNode when [values] is empty.
      [List.hd] / [List.tl] raise [Failure] if [dims] runs out before
      [values] is down to its last element.

      NOTE(review): every combinator here is [Mul]; if this is meant to
      linearise an array index, the trailing term would normally be added
      rather than multiplied — confirm the intended formula. *)
  let rec expand dims depth values =
    let reduce = dim_reduce depth in
    match values with
    | [] -> raise InvalidNode
    | [last] -> reduce last
    | value :: rest ->
      let scaled = Binop (Mul, reduce value, reduce (List.hd dims), noloc) in
      Binop (Mul, scaled, expand (List.tl dims) depth rest, noloc)
  (** [dim_reduce depth node] rewrites [node] so that multi-entry dimension and
      value lists become single expressions. [depth] is the nesting level
      carried down the traversal; it is incremented when entering a [FunDef].

      Nodes that match none of the cases below are rebuilt via
      [transform_children], recursing with the same [depth]. *)
  and dim_reduce depth node =
    match node with
    (* An allocation keeps a single dimension: the product of all of them. *)
    | Allocate (name, dims, dec, loc) ->
      Allocate (name, [multiply dims], dec, loc)

    (* Parameters and body of a function live one level deeper. *)
    | FunDef (export, ret_type, name, params, body, loc) ->
      let inner = dim_reduce (depth + 1) in
      FunDef (export, ret_type, name, List.map inner params, inner body, loc)

    (* A dimension reference becomes an [Int] variable use at the current
       traversal depth. *)
    | Dim (name, loc) ->
      VarUse (Var (name, loc), Int, depth)

    (* Typed dereference of an array: replace the value list by one expanded
       expression. The depth used is the one stored in the node itself, not
       the traversal parameter. *)
    | VarUse (Type (Deref (name, values, loc), t), (Array (_, dims) as ctype), node_depth) ->
      let reduced = [expand (List.rev dims) node_depth values] in
      VarUse (Type (Deref (name, reduced, loc), t), ctype, node_depth)

    (* Assignment into an array with explicit values: same treatment, and the
       assigned expression is reduced as well, again at the node's own depth. *)
    | VarLet (Assign (name, Some values, value, loc), (Array (_, dims) as ctype), node_depth) ->
      let reduced = Some [expand (List.rev dims) node_depth values] in
      VarLet (Assign (name, reduced, dim_reduce node_depth value, loc), ctype, node_depth)

    | other -> transform_children (dim_reduce depth) other
  (** [phase input] is the entry point of this pass. It prints a progress line
      to stderr, then applies [dim_reduce] at depth 0 to the tree wrapped in
      [Ast] and wraps the result again.
      @raise InvalidInput with ["dimension reduction"] for any other input. *)
  let phase input =
    let () = prerr_endline "- Array dimension reduction" in
    match input with
    | Ast tree -> Ast (dim_reduce 0 tree)
    | _ -> raise (InvalidInput "dimension reduction")