Torch.gather 来自使用二维索引的一维数组

Torch.gather from 1D array using 2D indices

我有一个 nx1 张量和一个 nxm 张量。我想使用 nxm 张量从 nx1 张量收集值。 例如 对于输入 tensor([1, 2, 3, 4])

索引tensor([[0, 3], [2, 1],[1, 3], [2,3]])

输出应该是

tensor([[1, 4], [3, 2], [2,4], [3,4])

索引在二维矩阵中,值将从一维列表中收集。

如何使用 torch.gather/ 或任何火炬张量函数来达到此目的? 我的以下代码给出了错误

t = torch.tensor([[1, 2, 3, 4]])
ind = torch.tensor([[0, 3], [2, 1],[1, 3], [2,3]])
torch.gather(t, 0, ind)

RuntimeError: index 2 is out of bounds for dimension 0 with size 1

编辑: 您可以进行简单的索引来实现此输出。

t[ind]

这是最好的方法吗?我假设这涉及广播输入数组。

编辑

在正向传递中使用 t[ind] 会导致错误

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [430,0,0], thread: [97,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.

当我尝试在正向传递中打印张量时,在 t[ind] 操作后没有显示任何输出。这是有道理的,因为 getitem 不是传播损失的可微分操作。 因此,在 getitem.

上使用 gather 是一个有效的用例

如果你想使用torch.gather:

torch.gather(t.expand(4, -1), 1, ind)