Pytorch,使用多个索引从 3D 张量中检索值。计算效率最高的解决方案

Pytorch, retrieving values from a 3D tensor using several indices. Most computationally efficient solution

相关:

这是另一个关于使用索引列表从 3D 张量中检索值的问题。

在这种情况下,我有一个 3d 张量,例如

b = [[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]]
tensor_b = torch.tensor(b)
tensor_b
tensor([[[ 4, 20],
         [ 1, -1]],

        [[ 1,  2],
         [ 8, -1]],

        [[92,  4],
         [23, -1]]])

在本例中,我有一个 3D 索引列表。所以

indices = [
           [[1, 0, 1], [2, 0, 1]],
           [[1, 1, 1], [0, 0, 0]],
           [[2, 1, 0], [0, 1, 0]]
]

每个三元组都是张量 b 的索引。期望的结果是

[[2, 4], [-1, 4], [23, 1]]

潜在方法

和上一个问题一样,第一个想到的解决方案是嵌套的for循环,但可能有一个使用pytorch函数的计算效率更高的解决方案。

和上一个问题一样,可能需要重塑以获得最后一个解决方案所需的形状。

因此,理想的解决方案可能是 [2, 4, -1, 4, 23, 1],它可以来自扁平化的索引列表 [ [1, 0, 1], [2, 0, 1], [1, 1, 1], [0, 0, 0], [2, 1, 0], [0, 1, 0] ]

但到目前为止,我还不知道有任何允许 3D 索引列表的 pytorch 函数。我一直在看 gatherindex_select

您可以专门使用 advanced indexing 整数数组索引

tensor_b = torch.tensor([[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]])
indices = torch.tensor([
           [[1, 0, 1], [2, 0, 1]],
           [[1, 1, 1], [0, 0, 0]],
           [[2, 1, 0], [0, 1, 0]]
])

result = tensor_b[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]

结果

tensor([[ 2,  4],
        [-1,  4],
        [23,  1]])