Pytorch 收集问题(3D 计算机视觉)
Pytorch gather question (3D Computer Vision)
我有N组C维点。每组有M个点。所以,有一个张量(N,M,C)。我们称之为功能。
我通过M维计算了最大元素和索引,求每个C维的最大点(一个max pooling操作),得到最大张量(N, 1, C)和索引张量(N, 1 , C).
我有另一个形状为(N, M, 3) 的张量,存储这N*M 个高维点的几何坐标。现在,我想使用每个 C 维度中最大点的索引,来获取所有这些最大点的坐标。
例如,N=2,M=4,C=6。
坐标张量,其形状为(2, 4, 3):
[[[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[8, 7, 6]]
[11, 12, 13]
[14, 15, 16]
[17, 18, 19]
[18, 17, 16]]]
指数张量,形状为(2, 1, 6):
[[[0, 1, 2, 1, 2, 3]]
[[1, 2, 3, 2, 1, 0]]]
比如indices中的第一个元素是0,我想从坐标张量中抓取[1,2,3]出来。对于第二个元素 (1),我想把 [4, 5, 6] 抓出来。对于下一个维度(3)中的第三个元素,我想把[18,17,16]抓出来
结果张量将如下所示:
[[[1, 2, 3] # 0
[4, 5, 6] # 1
[7, 8, 9] # 2
[4, 5, 6] # 1
[7, 8, 9] # 2
[8, 7, 6]] # 3
[[14, 15, 16] # 1
[17, 18, 19] # 2
[18, 17, 16] # 3
[17, 18, 19] # 2
[14, 15, 16] # 1
[11, 12, 13]]]# 0
它的形状是 (2, 6, 3).
我尝试使用 torch.gather 但无法正常工作。我写了一个简单的算法来枚举所有 N 个组,但它确实很慢,即使使用 TorchScript 的 JIT。那么,如何在pytorch中高效地编写这个呢?
您可以使用 integer array indexing combined with broadcasting semantics 得到结果。
import torch
x = torch.tensor([
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[8, 7, 6]],
[[11, 12, 13],
[14, 15, 16],
[17, 18, 19],
[18, 17, 16]],
])
i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
[[1, 2, 3, 2, 1, 0]]])
# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)
# y is [2, 6, ...]
y = x[rows, cols]
我有N组C维点。每组有M个点。所以,有一个张量(N,M,C)。我们称之为功能。
我通过M维计算了最大元素和索引,求每个C维的最大点(一个max pooling操作),得到最大张量(N, 1, C)和索引张量(N, 1 , C).
我有另一个形状为(N, M, 3) 的张量,存储这N*M 个高维点的几何坐标。现在,我想使用每个 C 维度中最大点的索引,来获取所有这些最大点的坐标。
例如,N=2,M=4,C=6。
坐标张量,其形状为(2, 4, 3):
[[[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[8, 7, 6]]
[11, 12, 13]
[14, 15, 16]
[17, 18, 19]
[18, 17, 16]]]
指数张量,形状为(2, 1, 6):
[[[0, 1, 2, 1, 2, 3]]
[[1, 2, 3, 2, 1, 0]]]
比如indices中的第一个元素是0,我想从坐标张量中抓取[1,2,3]出来。对于第二个元素 (1),我想把 [4, 5, 6] 抓出来。对于下一个维度(3)中的第三个元素,我想把[18,17,16]抓出来
结果张量将如下所示:
[[[1, 2, 3] # 0
[4, 5, 6] # 1
[7, 8, 9] # 2
[4, 5, 6] # 1
[7, 8, 9] # 2
[8, 7, 6]] # 3
[[14, 15, 16] # 1
[17, 18, 19] # 2
[18, 17, 16] # 3
[17, 18, 19] # 2
[14, 15, 16] # 1
[11, 12, 13]]]# 0
它的形状是 (2, 6, 3).
我尝试使用 torch.gather 但无法正常工作。我写了一个简单的算法来枚举所有 N 个组,但它确实很慢,即使使用 TorchScript 的 JIT。那么,如何在pytorch中高效地编写这个呢?
您可以使用 integer array indexing combined with broadcasting semantics 得到结果。
import torch
x = torch.tensor([
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[8, 7, 6]],
[[11, 12, 13],
[14, 15, 16],
[17, 18, 19],
[18, 17, 16]],
])
i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
[[1, 2, 3, 2, 1, 0]]])
# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)
# y is [2, 6, ...]
y = x[rows, cols]