使用 tf.gather 或 tf.gather_nd

Using tf.gather or tf.gather_nd

我有一个维度 (BATCH_SIZE*A*B*FEATURE_LENGTH) 的输入。现在我想从每个输入样本的每个 A 块中 select k(out of B) 行。每个 A 块的 k 值都不同。 例如

inp = ([[[[ 5, 38, 40, 13, 28],
         [12,  6, 36, 20, 23],
         [44, 35, 23, 46,  3]],

        [[22, 32, 36, 20, 42],
         [ 0, 19, 41, 36, 17],
         [ 9, 35, 44,  7, 19]],

        [[27, 10, 22, 10, 48],
         [16, 42, 27,  7, 38],
         [35, 32, 15, 39, 28]]]])
#size (1,3,3,5) = (1,A,B,FEATURE_LENGTH)

现在说 k=2 即我想从 3 个块中的每一个中提取 2 行。我要

row 0 and 1 from 1st block
row 1 and 2 from 2nd block
row 0 and 2 from 3rd block

这意味着我希望我的输出看起来像

([[[[ 5, 38, 40, 13, 28],
    [12,  6, 36, 20, 23]],
   
    [[ 0, 19, 41, 36, 17],
     [ 9, 35, 44,  7, 19]],

    [[27, 10, 22, 10, 48],
     [35, 32, 15, 39, 28]]]])
#op shape = (1,3,2,5)

我发现如果我们提供索引

,那么使用 tf.gather_nd 是可能的
ind = array([[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]])

但是如果我输入大小为 (1,16,16,128)k=4,创建这个长索引序列将变得乏味。 在 Tensorflow-2 中有更简单的方法吗? 谢谢!

tf.gather()batch_dims 参数一起使用:

inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)