使用 index_select 将一个 PyTorch 张量索引到另一个

Indexing one PyTorch tensor by another using index_select

我有一个 3 x 3 PyTorch LongTensor,看起来像这样:

A = 
    [0, 0, 0]
    [1, 2, 2]
    [1, 2, 3]

我想让它像这样索引一个 4 x 2 FloatTensor:

B = 
    [0.4, 0.5]
    [1.2, 1.4]
    [0.8, 1.9]
    [2.4, 2.9]

我的预期输出是下面的 2 x 3 x 3 FloatTensor:

C[0,:,:] = 
    [0.4, 0.4, 0.4]
    [1.2, 0.8, 0.8]
    [1.2, 0.8, 2.4]

C[1,:,:] =
    [0.5, 0.5, 0.5]
    [1.4, 1.9, 1.9]
    [1.4, 1.9, 2.9]

换句话说,矩阵A是索引和广播矩阵BAB的索引矩阵,所以这个操作本质上是一个索引操作。

如何使用 torch.index_select() 函数完成此操作?如果解决方案涉及添加或置换维度,那很好。

使用 index_select() 要求索引值在向量中而不是张量中。但只要格式正确,该函数就会为您处理广播。必须做的最后一件事是重塑输出,我相信是由于广播。

将成功执行此操作的一行是

torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)

A.view(-1) 向量化索引矩阵。

__.view(3,-1,2) 重塑回索引矩阵的形状,但考虑了大小为 2 的新额外维度(因为我正在索引 N x 2 矩阵)。

最后,__.permute(2,0,1) 重塑矩阵,以便输出在单独的通道(而不是每一列)中查看 B 的每个维度。