torch 中的单次多维索引 - 也许使用 index_select 或收集?

Single shot multi-dimension indexing in torch - perhaps with index_select or gather?

我正在根据对应数据对矩阵执行多索引重新排列。现在,我通过一对 index_select 调用来执行此操作,但这是非常低的内存效率(就内存使用而言为 n^2),并且在计算效率方面也不是很理想。有什么方法可以将我的操作简化为单个 .gather 或 .index_select 调用?

我本质上想要做的是,当给定一个形状为 (I,J,K) 的源数组和一个形状为 (I,J,2) 的索引数组时,产生满足条件的结果:

result[i][j][:] = source[idx[i][j][0]] [idx[i][j][1]] [:]

这是一个 运行我现在如何做事的可用玩具示例:

source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
indices = torch.tensor([[[2,2],[3,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]]])

ax1 = torch.index_select(source,0,indices[:,:,0].flatten())
ax2 = torch.index_select(ax1, 1, indices[:,:,1].flatten())

result = ax2.diagonal().reshape(indices.shape(0), indices.shape(1))

这种方法只适用于我,因为我的图像相当小,所以即使存在对角化问题,它们也适合内存。无论如何,我正在生成大量不需要的数据。此外,如果 K 变大,那么这个问题会呈指数级恶化。也许我只是在文档中遗漏了一些明显的东西,但我觉得这是其他人必须 运行 解决的问题,然后才能帮助我!

您已经为 integer array indexing 准备好了索引,所以我们可以简单地做

result = source[indices[..., 0], indices[..., 1], ...]