如何使用 torch.min 中的索引来索引多维张量
How to use indices from torch.min to index multi-dimensional tensor
使用 PyTorch,我有一个大小为 (b, 2, x, y)
的多维张量 A
,以及另一个大小为 (b, 2, x, y, 3)
.
的相关张量 B
我想得到 A
中跨越 dim=1
的最小值的索引(这个维度大小为 2),并将这个索引张量应用到 B
这样我就可以最终得到一个形状为 (b, x, y, 3)
.
的张量
通过使用 A_mins, indices = torch.min(A, dim=1)
我能够得到形状为 (b, x, y)
的张量 indices
其中值是 0
或 1
取决于哪个是 A
中 dim=1
的最小值。我不知道如何将其应用于 B
以获得所需的输出。我知道 torch.index_select
做了类似的工作,但只针对一维索引向量。
我认为更合适的函数是 torch.gather
. You should first apply torch.Tensor.argmin
(或与 torch.Tensor.min
相同),将 keepdim
选项设置为 True
并广播索引器(此处减少A
因为索引张量 B
有一个额外的维度):
>>> indexer = A.argmin(1,True).unsqueeze(-1).expand(*(-1,)*A.ndim, 3)
>>> out = torch.gather(B, 1, indexer)[:, 0]
形状方面:
indexer
张量将具有 (b, 1, x, y, 3)
的形状,其中最后一个维度本质上是对值的视图(我们从单例扩展到 three-channel torch.expand
).
生成的张量 out
在用 squeeze(1)
或等效方式挤压 dim=1
上的单例后将具有 (b, x, y, 3)
的形状使用 [:, 0]
索引...
使用 PyTorch,我有一个大小为 (b, 2, x, y)
的多维张量 A
,以及另一个大小为 (b, 2, x, y, 3)
.
B
我想得到 A
中跨越 dim=1
的最小值的索引(这个维度大小为 2),并将这个索引张量应用到 B
这样我就可以最终得到一个形状为 (b, x, y, 3)
.
通过使用 A_mins, indices = torch.min(A, dim=1)
我能够得到形状为 (b, x, y)
的张量 indices
其中值是 0
或 1
取决于哪个是 A
中 dim=1
的最小值。我不知道如何将其应用于 B
以获得所需的输出。我知道 torch.index_select
做了类似的工作,但只针对一维索引向量。
我认为更合适的函数是 torch.gather
. You should first apply torch.Tensor.argmin
(或与 torch.Tensor.min
相同),将 keepdim
选项设置为 True
并广播索引器(此处减少A
因为索引张量 B
有一个额外的维度):
>>> indexer = A.argmin(1,True).unsqueeze(-1).expand(*(-1,)*A.ndim, 3)
>>> out = torch.gather(B, 1, indexer)[:, 0]
形状方面:
indexer
张量将具有(b, 1, x, y, 3)
的形状,其中最后一个维度本质上是对值的视图(我们从单例扩展到 three-channeltorch.expand
).生成的张量
out
在用squeeze(1)
或等效方式挤压dim=1
上的单例后将具有(b, x, y, 3)
的形状使用[:, 0]
索引...