具有 2 个相同形状的张量的 pytorch 中的 Argmax 索引
Argmax indexing in pytorch with 2 tensors of equal shape
问题总结
我在 pytorch 中处理高维张量,我需要用另一个张量的 argmax 值索引一个张量。所以我需要用 dim [3,4]
的张量 x
和 dim [3,4]
的 argmax 的结果索引张量 y
of dim [3,4]
。如果张量是:
import torch as T
# Tensor to get argmax from
# expected argmax: [2, 0, 1]
x = T.tensor([[1, 2, 8, 3],
[6, 3, 3, 5],
[2, 8, 1, 7]])
# Tensor to index with argmax from preivous
# expected tensor to retrieve [2, 4, 9]
y = T.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
# argmax
x_max, x_argmax = T.max(x, dim=1)
我想要一个操作,给定 x
或 x_argmax
的 argmax 索引,在相同索引 x_argmax
索引中检索张量 y
中的值。
描述您尝试过的内容
这是我试过的:
# What I have tried
print(y[x_argmax])
print(y[:, x_argmax])
print(y[..., x_argmax])
print(y[x_argmax.unsqueeze(1)])
我已经阅读了很多关于 numpy 索引、基本索引、高级索引和组合索引的内容。我一直在尝试使用组合索引(因为我想要张量的第一维切片和第二维索引值)。但是我还没能为这个用例想出解决方案。
y[T.arange(3), x_argmax]
怎么样?
这适合我...
说明:调用T.max(x, dim=1)
时带走了维度信息,因此需要显式恢复此信息。
您正在寻找torch.gather
:
idx = torch.argmax(x, dim=1, keepdim=true) # get argmax directly, w/o max
out = torch.gather(y, 1, idx)
结果为
tensor([[2],
[4],
[9]])
问题总结
我在 pytorch 中处理高维张量,我需要用另一个张量的 argmax 值索引一个张量。所以我需要用 dim [3,4]
的张量 x
和 dim [3,4]
的 argmax 的结果索引张量 y
of dim [3,4]
。如果张量是:
import torch as T
# Tensor to get argmax from
# expected argmax: [2, 0, 1]
x = T.tensor([[1, 2, 8, 3],
[6, 3, 3, 5],
[2, 8, 1, 7]])
# Tensor to index with argmax from preivous
# expected tensor to retrieve [2, 4, 9]
y = T.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
# argmax
x_max, x_argmax = T.max(x, dim=1)
我想要一个操作,给定 x
或 x_argmax
的 argmax 索引,在相同索引 x_argmax
索引中检索张量 y
中的值。
描述您尝试过的内容
这是我试过的:
# What I have tried
print(y[x_argmax])
print(y[:, x_argmax])
print(y[..., x_argmax])
print(y[x_argmax.unsqueeze(1)])
我已经阅读了很多关于 numpy 索引、基本索引、高级索引和组合索引的内容。我一直在尝试使用组合索引(因为我想要张量的第一维切片和第二维索引值)。但是我还没能为这个用例想出解决方案。
y[T.arange(3), x_argmax]
怎么样?
这适合我...
说明:调用T.max(x, dim=1)
时带走了维度信息,因此需要显式恢复此信息。
您正在寻找torch.gather
:
idx = torch.argmax(x, dim=1, keepdim=true) # get argmax directly, w/o max
out = torch.gather(y, 1, idx)
结果为
tensor([[2],
[4],
[9]])