如何通过索引列表检索多维 pytorch 张量中的元素?

How can I retrieve elements in a multidimensional pytorch tensor by a list of indices?

我有两个张量:scoreslists
scores 的形状是 (x, 8)lists(x, 8, 4)。我想过滤 scores 中每一行的最大值,并从 lists.

中过滤相应的元素

以下面为例(形状尺寸8为简单起见减少为2):

scores = torch.tensor([[0.5, 0.4], [0.3, 0.8], ...])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]], 
                      [[0.1, 0.2, 0.1, 0.3], 
                       [0.4, 0.3, 0.2, 0.5]], ...])

然后我想将这些张量过滤为:

scores = torch.tensor([0.5, 0.8, ...])
lists = torch.tensor([[0.2, 0.3, 0.1, 0.5], [0.4, 0.3, 0.2, 0.5], ...])

注意: 到目前为止,我已尝试从原始分数向量中检索索引并将其用作索引向量来过滤列表:

# PSEUDO-CODE
indices = scores.argmax(dim=1)
for list, idx in zip(lists, indices):
    list = list[idx]

这也是问题名称的来源。

我想你试过类似的东西

indices = scores.argmax(dim=1)
selection = lists[:, indices]

这不起作用,因为索引是为维度 0 中的每个元素选择的,所以最终形状是 (x, x, 4)

执行正确的选择,您需要用范围替换切片。

indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]