具有 2 个相同形状的张量的 pytorch 中的 Argmax 索引

Argmax indexing in pytorch with 2 tensors of equal shape

问题总结

我在 pytorch 中处理高维张量,我需要用另一个张量的 argmax 值索引一个张量。所以我需要用 dim [3,4] 的张量 x 和 dim [3,4] 的 argmax 的结果索引张量 y of dim [3,4]。如果张量是:

import torch as T
# Tensor to get argmax from
# expected argmax: [2, 0, 1]
x = T.tensor([[1, 2, 8, 3],
              [6, 3, 3, 5],
              [2, 8, 1, 7]])

# Tensor to index with argmax from preivous
# expected tensor to retrieve [2, 4, 9]
y = T.tensor([[0,  1,  2,  3],
              [4,  5,  6,  7],
              [8,  9, 10, 11]])
# argmax
x_max, x_argmax = T.max(x, dim=1)

我想要一个操作,给定 xx_argmax 的 argmax 索引,在相同索引 x_argmax 索引中检索张量 y 中的值。

描述您尝试过的内容

这是我试过的:

# What I have tried
print(y[x_argmax])
print(y[:, x_argmax])
print(y[..., x_argmax])
print(y[x_argmax.unsqueeze(1)])

我已经阅读了很多关于 numpy 索引、基本索引、高级索引和组合索引的内容。我一直在尝试使用组合索引(因为我想要张量的第一维切片和第二维索引值)。但是我还没能为这个用例想出解决方案。

y[T.arange(3), x_argmax]怎么样?

这适合我...

说明:调用T.max(x, dim=1)时带走了维度信息,因此需要显式恢复此信息。

您正在寻找torch.gather

idx = torch.argmax(x, dim=1, keepdim=true)  # get argmax directly, w/o max
out = torch.gather(y, 1, idx)

结果为

tensor([[2],
        [4],
        [9]])