Pytorch 批量索引

Pytorch batch indexing

所以我的网络输出是这样的:

output = tensor([[[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.0315, -0.1837],
     [ 0.0318, -0.1828],
     [ 0.0322, -0.1822],
     [ 0.0324, -0.1819],
     [ 0.0327, -0.1817],
     [ 0.0328, -0.1815],
     [ 0.0330, -0.1815],
     [ 0.0331, -0.1814],
     [ 0.0332, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]]])

这是[8, 24, 2]的形状 现在 8 是我的批量大小。我想从以下位置的每个批次中获取一个数据点:

index = tensor([24, 10,  3,  3,  1,  1,  1,  0])

所以第一批的第 24 个值,第二批的第 10 个值,依此类推。

现在我在理解语法时遇到了问题。 我试过了

torch.gather(output, 0, index)

但它一直告诉我,我的尺寸不匹配。 并尝试

output[ : ,index]

给我获取每批次所有索引的值。 获取这些值的正确语法是什么?

要select每批只需要一个元素,您需要枚举批次索引,这可以通过torch.arange轻松完成。

output[torch.arange(output.size(0)), index]

这实质上是在枚举张量和您的 index 张量之间创建元组以访问数据,从而生成索引 output[0, 24]output[1, 10]

首先注意一点,对于输出形状 [8, 24, 2],第二维的最大索引可以是 23,所以我将您的索引修改为

index = torch.tensor([23, 10,  3,  3,  1,  1,  1,  0])
output = torch.randn((8,24,2)) # Toy data to represent your output

最简单的解决方案是使用 for 循环

data_pts = torch.zeros((8,2)) # Tensor to store desired values

for i,j in enumerate(index):
    data_pts[i, :] = output[i, j, :]

但是,如果您想对索引进行矢量化,您只需要所有维度的索引。例如,

data_pts_vectorized = output[range(8), index, :] 

由于你的索引向量是有序的,你可以用range生成第一维索引。

您可以确认这两种方法的结果相同

assert(torch.all(data_pts == data_pts_vectorized))