Pytorch:如何索引张量?

Pytorch: How to index a tensor?

我是 PyTorch 的新手,仍在思考如何形成正确的 gather 语句。我有一个大小为 (1,200,61,1632) 的 4D 输入张量,其中 1632 是时间维度。我想用大小为 (4,1632) 的张量 idx 对其进行索引,其中 idx 的每一行都是我想从 input 张量中提取的值。所以 idx 的行看起来像:

[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...

因此输出的大小为 1632。换句话说,我想这样做:

output = []
for i in range(1632):
  output.append(input[idx[0,i], idx[1,i], idx[2,i], idx[3,i]])

这是 torch.gather 的合适用例吗? 查看 gather 的文档,它说输入张量和索引张量必须具有相同的形状。

自 PyTorch doesn't offer an implementation of ravel_multi_index 以来,丑陋的做法是:

output = input[idx[0, :], idx[1, :], idx[2, :], idx[3, :]]

在 NumPy 中,您可以这样做:

output = np.take(input, np.ravel_multi_index(idx, input.shape))