Pytorch 批量索引
Pytorch batch indexing
所以我的网络输出是这样的:
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
这是[8, 24, 2]
的形状
现在 8 是我的批量大小。我想从以下位置的每个批次中获取一个数据点:
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
所以第一批的第 24 个值,第二批的第 10 个值,依此类推。
现在我在理解语法时遇到了问题。
我试过了
torch.gather(output, 0, index)
但它一直告诉我,我的尺寸不匹配。
并尝试
output[ : ,index]
给我获取每批次所有索引的值。
获取这些值的正确语法是什么?
要select每批只需要一个元素,您需要枚举批次索引,这可以通过torch.arange
轻松完成。
output[torch.arange(output.size(0)), index]
这实质上是在枚举张量和您的 index
张量之间创建元组以访问数据,从而生成索引 output[0, 24]
、output[1, 10]
等
首先注意一点,对于输出形状 [8, 24, 2],第二维的最大索引可以是 23,所以我将您的索引修改为
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])
output = torch.randn((8,24,2)) # Toy data to represent your output
最简单的解决方案是使用 for 循环
data_pts = torch.zeros((8,2)) # Tensor to store desired values
for i,j in enumerate(index):
data_pts[i, :] = output[i, j, :]
但是,如果您想对索引进行矢量化,您只需要所有维度的索引。例如,
data_pts_vectorized = output[range(8), index, :]
由于你的索引向量是有序的,你可以用range
生成第一维索引。
您可以确认这两种方法的结果相同
assert(torch.all(data_pts == data_pts_vectorized))
所以我的网络输出是这样的:
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
这是[8, 24, 2]
的形状
现在 8 是我的批量大小。我想从以下位置的每个批次中获取一个数据点:
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
所以第一批的第 24 个值,第二批的第 10 个值,依此类推。
现在我在理解语法时遇到了问题。 我试过了
torch.gather(output, 0, index)
但它一直告诉我,我的尺寸不匹配。 并尝试
output[ : ,index]
给我获取每批次所有索引的值。 获取这些值的正确语法是什么?
要select每批只需要一个元素,您需要枚举批次索引,这可以通过torch.arange
轻松完成。
output[torch.arange(output.size(0)), index]
这实质上是在枚举张量和您的 index
张量之间创建元组以访问数据,从而生成索引 output[0, 24]
、output[1, 10]
等
首先注意一点,对于输出形状 [8, 24, 2],第二维的最大索引可以是 23,所以我将您的索引修改为
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])
output = torch.randn((8,24,2)) # Toy data to represent your output
最简单的解决方案是使用 for 循环
data_pts = torch.zeros((8,2)) # Tensor to store desired values
for i,j in enumerate(index):
data_pts[i, :] = output[i, j, :]
但是,如果您想对索引进行矢量化,您只需要所有维度的索引。例如,
data_pts_vectorized = output[range(8), index, :]
由于你的索引向量是有序的,你可以用range
生成第一维索引。
您可以确认这两种方法的结果相同
assert(torch.all(data_pts == data_pts_vectorized))