使用 index_select 将一个 PyTorch 张量索引到另一个
Indexing one PyTorch tensor by another using index_select
我有一个 3 x 3 PyTorch LongTensor,看起来像这样:
A =
[0, 0, 0]
[1, 2, 2]
[1, 2, 3]
我想让它像这样索引一个 4 x 2 FloatTensor:
B =
[0.4, 0.5]
[1.2, 1.4]
[0.8, 1.9]
[2.4, 2.9]
我的预期输出是下面的 2 x 3 x 3 FloatTensor:
C[0,:,:] =
[0.4, 0.4, 0.4]
[1.2, 0.8, 0.8]
[1.2, 0.8, 2.4]
C[1,:,:] =
[0.5, 0.5, 0.5]
[1.4, 1.9, 1.9]
[1.4, 1.9, 2.9]
换句话说,矩阵A
是索引和广播矩阵B
。 A
是B
的索引矩阵,所以这个操作本质上是一个索引操作。
如何使用 torch.index_select()
函数完成此操作?如果解决方案涉及添加或置换维度,那很好。
使用 index_select()
要求索引值在向量中而不是张量中。但只要格式正确,该函数就会为您处理广播。必须做的最后一件事是重塑输出,我相信是由于广播。
将成功执行此操作的一行是
torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)
A.view(-1)
向量化索引矩阵。
__.view(3,-1,2)
重塑回索引矩阵的形状,但考虑了大小为 2 的新额外维度(因为我正在索引 N x 2 矩阵)。
最后,__.permute(2,0,1)
重塑矩阵,以便输出在单独的通道(而不是每一列)中查看 B
的每个维度。
我有一个 3 x 3 PyTorch LongTensor,看起来像这样:
A =
[0, 0, 0]
[1, 2, 2]
[1, 2, 3]
我想让它像这样索引一个 4 x 2 FloatTensor:
B =
[0.4, 0.5]
[1.2, 1.4]
[0.8, 1.9]
[2.4, 2.9]
我的预期输出是下面的 2 x 3 x 3 FloatTensor:
C[0,:,:] =
[0.4, 0.4, 0.4]
[1.2, 0.8, 0.8]
[1.2, 0.8, 2.4]
C[1,:,:] =
[0.5, 0.5, 0.5]
[1.4, 1.9, 1.9]
[1.4, 1.9, 2.9]
换句话说,矩阵A
是索引和广播矩阵B
。 A
是B
的索引矩阵,所以这个操作本质上是一个索引操作。
如何使用 torch.index_select()
函数完成此操作?如果解决方案涉及添加或置换维度,那很好。
使用 index_select()
要求索引值在向量中而不是张量中。但只要格式正确,该函数就会为您处理广播。必须做的最后一件事是重塑输出,我相信是由于广播。
将成功执行此操作的一行是
torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)
A.view(-1)
向量化索引矩阵。
__.view(3,-1,2)
重塑回索引矩阵的形状,但考虑了大小为 2 的新额外维度(因为我正在索引 N x 2 矩阵)。
最后,__.permute(2,0,1)
重塑矩阵,以便输出在单独的通道(而不是每一列)中查看 B
的每个维度。