如何使用 torch.min 中的索引来索引多维张量

How to use indices from torch.min to index multi-dimensional tensor

使用 PyTorch,我有一个大小为 (b, 2, x, y) 的多维张量 A,以及另一个大小为 (b, 2, x, y, 3).

的相关张量 B

我想得到 A 中跨越 dim=1 的最小值的索引(这个维度大小为 2),并将这个索引张量应用到 B 这样我就可以最终得到一个形状为 (b, x, y, 3).

的张量

通过使用 A_mins, indices = torch.min(A, dim=1) 我能够得到形状为 (b, x, y) 的张量 indices 其中值是 01 取决于哪个是 Adim=1 的最小值。我不知道如何将其应用于 B 以获得所需的输出。我知道 torch.index_select 做了类似的工作,但只针对一维索引向量。

我认为更合适的函数是 torch.gather. You should first apply torch.Tensor.argmin(或与 torch.Tensor.min 相同),将 keepdim 选项设置为 True 并广播索引器(此处减少A 因为索引张量 B 有一个额外的维度):

>>> indexer = A.argmin(1,True).unsqueeze(-1).expand(*(-1,)*A.ndim, 3)
>>> out = torch.gather(B, 1, indexer)[:, 0]

形状方面:

  • indexer 张量将具有 (b, 1, x, y, 3) 的形状,其中最后一个维度本质上是对值的视图(我们从单例扩展到 three-channel torch.expand).

  • 生成的张量 out 在用 squeeze(1) 或等效方式挤压 dim=1 上的单例后将具有 (b, x, y, 3) 的形状使用 [:, 0] 索引...