ソースを参照

Fixed index calculation in array dimension reduction

Taddeus Kroes 12 年 前
コミット
0b8bbab84a
1 ファイル変更21 行追加14 行削除
  1. 21 14
      phases/dim_reduce.ml

+ 21 - 14
phases/dim_reduce.ml

@@ -25,7 +25,7 @@ let rec expand_dims = function
         Block (do_expand dims)
 
     (* Add additional function arguments for array dimensions *)
-    | Arg (VarUse (VarDec (ArrayDims (ctype, dims), name, None, decann), None, ann)) as node ->
+    | Arg (VarUse (VarDec (ArrayDims (ctype, dims), name, None, decann), None, ann)) ->
         let rec do_expand = function
             | [] ->
                 (* Remove the (now obsolete dimensions fromt the type) *)
@@ -51,18 +51,25 @@ let rec multiply = function
     | [node]   -> node
     | hd :: tl -> Binop (Mul, hd, multiply tl, [Type Int])
 
-let use_dim depth = function
-    | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
-    (*| VarUse (dim, None, ann) -> VarUse ()*)
-    | node -> node
-
-let rec expand depth dims = function
+let rec multiply_all = function
     | []       -> raise InvalidNode
-    | [node]   -> dim_reduce depth node
-    | hd :: tl ->
-        let dim = use_dim depth (List.hd dims) in
-        let mul = Binop (Mul, dim_reduce depth hd, dim, [Type Int]) in
-        Binop (Add, mul, expand depth (List.tl dims) tl, [Type Int])
+    | [node]   -> node
+    | hd :: tl -> Binop (Mul, hd, multiply_all tl, [])
+
+let rec expand depth dims =
+    let rec do_expand dims = function
+        | []       -> raise InvalidNode
+        | [node]   -> dim_reduce depth node
+        | i :: j :: tl ->
+            let parent_width = List.hd dims in
+            let mul = Binop (Mul, dim_reduce depth i, parent_width, [Type Int]) in
+            do_expand (List.tl dims) (Binop (Add, mul, j, [Type Int]) :: tl)
+    in
+    let use_dim = function
+        | Dim _ as dim -> VarUse (dim, None, [Type Int; Depth depth])
+        | node -> node
+    in
+    do_expand (List.map use_dim (List.tl dims))
 
 and dim_reduce depth = function
     | Allocate (dec, dims, ann) ->
@@ -77,7 +84,7 @@ and dim_reduce depth = function
     | VarUse (dec, Some values, ann) as node ->
         (match typeof dec with
         | ArrayDims (_, dims) ->
-            VarUse (dec, Some [expand depth (List.rev dims) values], ann)
+            VarUse (dec, Some [expand depth dims values], ann)
         | _ -> node
         )
 
@@ -85,7 +92,7 @@ and dim_reduce depth = function
     | VarLet (dec, Some values, value, ann) as node ->
         (match typeof dec with
         | ArrayDims (_, dims) ->
-            VarLet (dec, Some [expand depth (List.rev dims) values], value, ann)
+            VarLet (dec, Some [expand depth dims values], value, ann)
         | _ -> node
         )