Pytorch,使用多个索引从 3D 张量中检索值。计算效率最高的解决方案
Pytorch, retrieving values from a 3D tensor using several indices. Most computationally efficient solution
相关:
这是另一个关于使用索引列表从 3D 张量中检索值的问题。
在这种情况下,我有一个 3d 张量,例如
b = [[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]]
tensor_b = torch.tensor(b)
tensor_b
tensor([[[ 4, 20],
[ 1, -1]],
[[ 1, 2],
[ 8, -1]],
[[92, 4],
[23, -1]]])
在本例中,我有一个 3D 索引列表。所以
indices = [
[[1, 0, 1], [2, 0, 1]],
[[1, 1, 1], [0, 0, 0]],
[[2, 1, 0], [0, 1, 0]]
]
每个三元组都是张量 b 的索引。期望的结果是
[[2, 4], [-1, 4], [23, 1]]
潜在方法
和上一个问题一样,第一个想到的解决方案是嵌套的for循环,但可能有一个使用pytorch函数的计算效率更高的解决方案。
和上一个问题一样,可能需要重塑以获得最后一个解决方案所需的形状。
因此,理想的解决方案可能是 [2, 4, -1, 4, 23, 1]
,它可以来自扁平化的索引列表
[ [1, 0, 1], [2, 0, 1], [1, 1, 1], [0, 0, 0], [2, 1, 0], [0, 1, 0] ]
但到目前为止,我还不知道有任何允许 3D 索引列表的 pytorch 函数。我一直在看 gather
和 index_select
。
您可以专门使用 advanced indexing 整数数组索引
tensor_b = torch.tensor([[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]])
indices = torch.tensor([
[[1, 0, 1], [2, 0, 1]],
[[1, 1, 1], [0, 0, 0]],
[[2, 1, 0], [0, 1, 0]]
])
result = tensor_b[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]
结果
tensor([[ 2, 4],
[-1, 4],
[23, 1]])
相关:
这是另一个关于使用索引列表从 3D 张量中检索值的问题。
在这种情况下,我有一个 3d 张量,例如
b = [[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]]
tensor_b = torch.tensor(b)
tensor_b
tensor([[[ 4, 20],
[ 1, -1]],
[[ 1, 2],
[ 8, -1]],
[[92, 4],
[23, -1]]])
在本例中,我有一个 3D 索引列表。所以
indices = [
[[1, 0, 1], [2, 0, 1]],
[[1, 1, 1], [0, 0, 0]],
[[2, 1, 0], [0, 1, 0]]
]
每个三元组都是张量 b 的索引。期望的结果是
[[2, 4], [-1, 4], [23, 1]]
潜在方法
和上一个问题一样,第一个想到的解决方案是嵌套的for循环,但可能有一个使用pytorch函数的计算效率更高的解决方案。
和上一个问题一样,可能需要重塑以获得最后一个解决方案所需的形状。
因此,理想的解决方案可能是 [2, 4, -1, 4, 23, 1]
,它可以来自扁平化的索引列表
[ [1, 0, 1], [2, 0, 1], [1, 1, 1], [0, 0, 0], [2, 1, 0], [0, 1, 0] ]
但到目前为止,我还不知道有任何允许 3D 索引列表的 pytorch 函数。我一直在看 gather
和 index_select
。
您可以专门使用 advanced indexing 整数数组索引
tensor_b = torch.tensor([[[4, 20], [1, -1]], [[1, 2], [8, -1]], [[92, 4], [23, -1]]])
indices = torch.tensor([
[[1, 0, 1], [2, 0, 1]],
[[1, 1, 1], [0, 0, 0]],
[[2, 1, 0], [0, 1, 0]]
])
result = tensor_b[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]
结果
tensor([[ 2, 4],
[-1, 4],
[23, 1]])