Futhark 尺寸不匹配问题
Futhark dimension mismatch issue
我正在尝试实现 groupby 功能。据我推断,我的代码应该是正确的。
以下是关键部分:
-- | Combine two values with the aggregation selected by `typ`.
--
-- Codes: 1 multiplies, 2 adds, 3 takes the maximum, 4 takes the minimum.
-- Any other code currently falls back to the minimum as well.
let type_func (typ: i32) (v1: u32) (v2: u32) : u32 =
  match typ
  case 1 -> v1 * v2
  case 2 -> v1 + v2
  case 3 -> u32.max v1 v2
  case 4 -> u32.min v1 v2
  -- Unknown code: the value itself is not needed, so nothing is bound.
  case _ -> u32.min v1 v2 -- TODO change to some panic function
-- | Element-wise merge of two rows `a` and `b` of equal length `m`.
--
-- Position 0 is copied from `a` unchanged. Every later position `i` is
-- `type_func s_cols_t[i-1] a[i] b[i]`, i.e. the code stored at `i-1` in
-- `s_cols_t` selects how `a[i]` and `b[i]` are combined.
--
-- NOTE(review): the index `i-1` into `s_cols_t` is only in bounds when
-- `n >= m-1`; the size parameters do not enforce this -- confirm at call sites.
let merge [n][m] (s_cols_t: [n]i32) (a: [m]u32) (b: [m]u32) : [m]u32 =
  iota m
  |> map (\pos ->
            if pos == 0
            then a[pos]
            else type_func s_cols_t[pos-1] a[pos] b[pos])
-- | Attempted group-by over the table `db` (`n` rows, `m` columns).
--
-- `g_col` is the index of one column of `db`, `s_cols` holds `t` further
-- column indices, and `t_cols` holds `t` codes that are handed to `merge`.
-- `rsort`, `mk_flags` and `segmented_reduce` are defined outside this block.
--
-- NOTE(review): this version is rejected by the type checker. The row width
-- of `sorted_rows` stems from the size returned by `concat`, which the
-- compiler reports as unknown, whereas the neutral element is built with
-- `replicate (length keep_inter) 0`, whose size is the value of that
-- argument. The two sizes are not recognised as equal.
let main [n][m][t] (db : [n][m]u32) (g_col: i32) (s_cols: [t]i32) (t_cols: [t]i32) : [][]u32 =
-- Column `g_col` of `db`: one value per row.
let keep_g = db[:, g_col]
-- One column of `db` for every index in `s_cols`.
let keep_s_cols = map (\c -> db[:, c]) s_cols
-- `keep_g` followed by the columns in `keep_s_cols`; the outer size of this
-- result is the `concat` size the compiler treats as unknown.
let keep_inter = concat [keep_g] (keep_s_cols)
-- Back to row-major order: one row per row of `db`.
let keep = transpose keep_inter
let sorted_rows = rsort keep -- ideally pass groupby col here
-- `mk_flags` is applied to column 0 of the sorted rows.
let idxs = mk_flags sorted_rows[:, 0]
-- Boolean flags: `true` wherever `mk_flags` produced a 1.
let flag = map (== 1) idxs
-- Operator combining two rows according to the codes in `t_cols`.
let helper = merge t_cols
-- The neutral element is sized by `length keep_inter`; this is the argument
-- whose size fails to match the row width of `sorted_rows`.
in segmented_reduce helper (replicate (length keep_inter) 0) flag sorted_rows
但是编译器抛出以下错误。
[0]> :l groupby.fut
Loading groupby.fut
Error at groupby.fut:63:70-80 :
Cannot apply "segmented_reduce" to "sorted_rows" (invalid type).
Expected: [n][argdim₃₅]u32
Actual: *[n][ret₁₃]u32
Dimensions "argdim₃₅" and "ret₁₃" do not match.
Note: "argdim₃₅" is value of argument
length keep_inter
passed to "replicate" at 63:42-58.
Note: "ret₁₃" is unknown size returned by "concat" at 57:20-48.
我在 REPL 中逐行输入代码,手动确认了 argdim₃₅ 和 ret₁₃ 这两个尺寸实际上是相等的。这只是我碰到的编译器限制,还是我做了什么蠢事?
我有一个取巧(hacky)的办法可以绕过这个问题:把 main 的大部分逻辑移到另一个函数中,并在该函数的定义里显式声明所用到的各个尺寸参数。
-- | Group-by over `db` restricted to the columns listed in `cols`.
--
-- `cols` holds `s` column indices into the rows of `db`; `t_cols` holds the
-- codes handed to `merge`. Because `s` is an explicit size parameter, the
-- neutral element `replicate s 0` has the same width as the projected rows.
-- `rsort`, `mk_flags` and `segmented_reduce` are defined outside this block.
let groupby [n][m][s][t] (db : [n][m]u32) (cols: [s]i32) (t_cols: [t]i32) : [][]u32 =
  -- Project every row of `db` onto the requested columns: an [n][s] table.
  let projected = map (\row -> map (\c -> row[c]) cols) db
  let ordered = rsort projected -- ideally pass groupby col here
  -- Flags computed from column 0 of the sorted rows; a 1 becomes `true`.
  let seg_starts = map (== 1) (mk_flags ordered[:, 0])
  in segmented_reduce (merge t_cols) (replicate s 0) seg_starts ordered
-- | Entry point: prepend `g_col` to `s_cols` and delegate to `groupby`, so the
-- column index `g_col` comes first among the selected columns.
let main db g_col s_cols t_cols =
  -- Binding the concatenation to a name gives its size a name before it is
  -- passed on as the `cols` argument of `groupby`.
  let all_cols = concat [g_col] s_cols
  in groupby db all_cols t_cols
这是一个极其粗暴且丑陋的解决方案,但只要管用就行。
我正在尝试实现 groupby 功能。据我推断,我的代码应该是正确的。
以下是关键部分:
-- | Combine two values with the aggregation selected by `typ`.
--
-- Codes: 1 multiplies, 2 adds, 3 takes the maximum, 4 takes the minimum.
-- Any other code currently falls back to the minimum as well.
let type_func (typ: i32) (v1: u32) (v2: u32) : u32 =
  match typ
  case 1 -> v1 * v2
  case 2 -> v1 + v2
  case 3 -> u32.max v1 v2
  case 4 -> u32.min v1 v2
  -- Unknown code: the value itself is not needed, so nothing is bound.
  case _ -> u32.min v1 v2 -- TODO change to some panic function
-- | Element-wise merge of two rows `a` and `b` of equal length `m`.
--
-- Position 0 is copied from `a` unchanged. Every later position `i` is
-- `type_func s_cols_t[i-1] a[i] b[i]`, i.e. the code stored at `i-1` in
-- `s_cols_t` selects how `a[i]` and `b[i]` are combined.
--
-- NOTE(review): the index `i-1` into `s_cols_t` is only in bounds when
-- `n >= m-1`; the size parameters do not enforce this -- confirm at call sites.
let merge [n][m] (s_cols_t: [n]i32) (a: [m]u32) (b: [m]u32) : [m]u32 =
  iota m
  |> map (\pos ->
            if pos == 0
            then a[pos]
            else type_func s_cols_t[pos-1] a[pos] b[pos])
-- | Attempted group-by over the table `db` (`n` rows, `m` columns).
--
-- `g_col` is the index of one column of `db`, `s_cols` holds `t` further
-- column indices, and `t_cols` holds `t` codes that are handed to `merge`.
-- `rsort`, `mk_flags` and `segmented_reduce` are defined outside this block.
--
-- NOTE(review): this version is rejected by the type checker. The row width
-- of `sorted_rows` stems from the size returned by `concat`, which the
-- compiler reports as unknown, whereas the neutral element is built with
-- `replicate (length keep_inter) 0`, whose size is the value of that
-- argument. The two sizes are not recognised as equal.
let main [n][m][t] (db : [n][m]u32) (g_col: i32) (s_cols: [t]i32) (t_cols: [t]i32) : [][]u32 =
-- Column `g_col` of `db`: one value per row.
let keep_g = db[:, g_col]
-- One column of `db` for every index in `s_cols`.
let keep_s_cols = map (\c -> db[:, c]) s_cols
-- `keep_g` followed by the columns in `keep_s_cols`; the outer size of this
-- result is the `concat` size the compiler treats as unknown.
let keep_inter = concat [keep_g] (keep_s_cols)
-- Back to row-major order: one row per row of `db`.
let keep = transpose keep_inter
let sorted_rows = rsort keep -- ideally pass groupby col here
-- `mk_flags` is applied to column 0 of the sorted rows.
let idxs = mk_flags sorted_rows[:, 0]
-- Boolean flags: `true` wherever `mk_flags` produced a 1.
let flag = map (== 1) idxs
-- Operator combining two rows according to the codes in `t_cols`.
let helper = merge t_cols
-- The neutral element is sized by `length keep_inter`; this is the argument
-- whose size fails to match the row width of `sorted_rows`.
in segmented_reduce helper (replicate (length keep_inter) 0) flag sorted_rows
但是编译器抛出以下错误。
[0]> :l groupby.fut
Loading groupby.fut
Error at groupby.fut:63:70-80 :
Cannot apply "segmented_reduce" to "sorted_rows" (invalid type).
Expected: [n][argdim₃₅]u32
Actual: *[n][ret₁₃]u32
Dimensions "argdim₃₅" and "ret₁₃" do not match.
Note: "argdim₃₅" is value of argument
length keep_inter
passed to "replicate" at 63:42-58.
Note: "ret₁₃" is unknown size returned by "concat" at 57:20-48.
我在 REPL 中逐行输入代码,手动确认了 argdim₃₅ 和 ret₁₃ 这两个尺寸实际上是相等的。这只是我碰到的编译器限制,还是我做了什么蠢事?
我有一个取巧(hacky)的办法可以绕过这个问题:把 main 的大部分逻辑移到另一个函数中,并在该函数的定义里显式声明所用到的各个尺寸参数。
-- | Group-by over `db` restricted to the columns listed in `cols`.
--
-- `cols` holds `s` column indices into the rows of `db`; `t_cols` holds the
-- codes handed to `merge`. Because `s` is an explicit size parameter, the
-- neutral element `replicate s 0` has the same width as the projected rows.
-- `rsort`, `mk_flags` and `segmented_reduce` are defined outside this block.
let groupby [n][m][s][t] (db : [n][m]u32) (cols: [s]i32) (t_cols: [t]i32) : [][]u32 =
  -- Project every row of `db` onto the requested columns: an [n][s] table.
  let projected = map (\row -> map (\c -> row[c]) cols) db
  let ordered = rsort projected -- ideally pass groupby col here
  -- Flags computed from column 0 of the sorted rows; a 1 becomes `true`.
  let seg_starts = map (== 1) (mk_flags ordered[:, 0])
  in segmented_reduce (merge t_cols) (replicate s 0) seg_starts ordered
-- | Entry point: prepend `g_col` to `s_cols` and delegate to `groupby`, so the
-- column index `g_col` comes first among the selected columns.
let main db g_col s_cols t_cols =
  -- Binding the concatenation to a name gives its size a name before it is
  -- passed on as the `cols` argument of `groupby`.
  let all_cols = concat [g_col] s_cols
  in groupby db all_cols t_cols
这是一个极其粗暴且丑陋的解决方案,但只要管用就行。