Futhark 尺寸不匹配问题

Futhark dimension mismatch issue

我正在尝试用 Futhark 实现 groupby 功能。经过推理，我认为我的代码在逻辑上应该是正确的。

以下是代码中相关的部分：

-- Combine two values with the aggregation operator selected by `typ`:
--   1 -> product, 2 -> sum, 3 -> maximum, 4 -> minimum.
-- Any other code currently falls back to minimum.
-- TODO: unknown codes should panic instead of silently using minimum.
let type_func (typ: i32) (v1 : u32)  (v2: u32) : u32 =
  if typ == 1 then v1 * v2
  else if typ == 2 then v1 + v2
  else if typ == 3 then u32.max v1 v2
  else u32.min v1 v2 -- code 4 and every unrecognised code


-- Combine two rows `a` and `b` of the same width `m`.
-- Column 0 (the group key) is copied from `a` unchanged; every other column i
-- is combined with the operator chosen by `s_cols_t[i-1]` (see `type_func`).
-- Requires n >= m - 1, otherwise `s_cols_t[i-1]` is indexed out of bounds.
let merge [n][m] (s_cols_t: [n]i32) (a: [m]u32) (b: [m]u32) : [m]u32 =
  map3 (\col x y -> if col == 0 then x
                    else type_func s_cols_t[col-1] x y)
       (iota m) a b

-- Group-by entry point. Keeps column `g_col` of `db` followed by the columns
-- listed in `s_cols`, sorts the resulting rows, flags group boundaries on
-- column 0, and reduces each segment with `merge t_cols`.
--
-- Size fix: the neutral element used to be `replicate (length keep_inter) 0`.
-- Its size is the *value* of an expression, which the type checker cannot
-- relate to the unknown size returned by `concat`. That caused the reported
-- "argdim vs ret" mismatch, even though both are equal at runtime. Mapping over
-- `keep_inter` instead produces an array whose size is, by construction, the
-- same size variable as the row width of `sorted_rows`.
let main [n][m][t] (db : [n][m]u32)  (g_col: i32) (s_cols: [t]i32) (t_cols: [t]i32) : [][]u32 =
  let keep_g = db[:, g_col]
  let keep_s_cols = map (\c -> db[:, c]) s_cols
  -- One row per kept column: [ret][n]u32, where ret is concat's result size.
  let keep_inter = concat [keep_g] (keep_s_cols)
  let keep = transpose keep_inter
  let sorted_rows = rsort keep -- ideally pass groupby col here
  let idxs = mk_flags sorted_rows[:, 0]
  let flag = map (== 1) idxs
  let helper = merge t_cols
  -- One zero per kept column, typed [ret]u32 so it matches the rows above.
  -- NOTE(review): zero is only a neutral element for `+` (code 2) and
  -- `u32.max` (code 3). It is not neutral for `*` (code 1), `u32.min`
  -- (code 4/default), or column 0, which `merge` copies from its left operand.
  -- Futhark's reduce/scan require a true neutral element, so results may be
  -- wrong on parallel backends. TODO confirm against segmented_reduce.
  let ne = map (\_ -> 0u32) keep_inter
  in segmented_reduce helper ne flag sorted_rows

但是编译器报出了以下错误：

[0]> :l groupby.fut
Loading groupby.fut
Error at groupby.fut:63:70-80 :
Cannot apply "segmented_reduce" to "sorted_rows" (invalid type).
Expected: [n][argdim₃₅]u32
Actual:   *[n][ret₁₃]u32

Dimensions "argdim₃₅" and "ret₁₃" do not match.

Note: "argdim₃₅" is value of argument
        length keep_inter
      passed to "replicate" at 63:42-58.

Note: "ret₁₃" is unknown size returned by "concat" at 57:20-48.

我在 REPL 中逐行输入代码，手动检查了 argdim₃₅ 和 ret₁₃ 的维度，确认它们在运行时是相同的。那么，这只是我碰到了编译器的限制，还是我做了什么蠢事？

我找到了一个比较取巧（hacky）的办法来绕过这个问题：把 main 函数的大部分逻辑移到另一个函数中，并在该函数的定义里显式声明所用各维度的大小。这样一来，行宽就是具名的尺寸参数 s，`replicate s 0` 的类型 `[s]u32` 便能与 `[n][s]u32` 中每一行的类型匹配。

-- Workaround group-by. `cols` lists the group-by column first, followed by the
-- columns to aggregate. The row width is the named size parameter `s`, so
-- `replicate s 0` has type [s]u32 and matches the [n][s]u32 rows exactly.
-- NOTE(review): zero is not a neutral element for `*`/`u32.min` or for
-- column 0 under `merge`; confirm this is acceptable for segmented_reduce.
let groupby [n][m][s][t] (db : [n][m]u32)  (cols: [s]i32)  (t_cols: [t]i32) : [][]u32 =
  -- Project every row of `db` onto the selected columns: [n][s]u32.
  let projected = map (\row -> map (\c -> row[c]) cols) db
  let ordered = rsort projected -- ideally pass groupby col here
  -- A row begins a new segment where mk_flags yields 1 for its key
  -- (presumably at key changes; TODO confirm mk_flags' contract).
  let seg_starts = map (== 1) (mk_flags ordered[:, 0])
  in segmented_reduce (merge t_cols) (replicate s 0) seg_starts ordered

-- Entry point for the workaround: put the group-by column in front of the
-- aggregated columns and delegate everything else to `groupby`.
let main db g_col s_cols t_cols =
  groupby db (concat [g_col] s_cols) t_cols

这个解决方案非常粗暴、丑陋，但既然能用，那就先这样吧。