在 PyTorch 中按列表在轴上进行索引

Indexing on axis by list in PyTorch

我有大小为 (10L,) 的变量 lengths_X 和大小为 (10L, 16L, 5L) 的变量。

我想使用 lengths_X 沿 A 的第二个轴进行索引。换句话说,我想获得一个大小为 (10L, 5L) 的新张量 predicted_Y,索引轴 1在 i 对于轴 0 中索引为 i 的所有条目。

在 PyTorch 中执行此操作的最佳方法是什么?

您要查找的实际上称为 batched_index_select,我之前曾寻找过此类功能,但在 PyTorch 中找不到任何可以完成这项工作的本机函数。但是我们可以简单地使用:

A = torch.randn(10, 16, 5)
index = torch.from_numpy(numpy.random.randint(0, 16, size=10))
B = torch.stack([a[i] for a, i in zip(A, index)])

你可以看到讨论here. You can also check out the function batched_index_select provided in the AllenNLP库。我很高兴知道是否有更好的解决方案。