Pytorch 收集问题(3D 计算机视觉)

Pytorch gather question (3D Computer Vision)

我有N组C维点。每组有M个点。所以,有一个张量(N,M,C)。我们称之为功能。

我通过M维计算了最大元素和索引,求每个C维的最大点(一个max pooling操作),得到最大张量(N, 1, C)和索引张量(N, 1 , C).

我有另一个形状为(N, M, 3) 的张量,存储这N*M 个高维点的几何坐标。现在,我想使用每个 C 维度中最大点的索引,来获取所有这些最大点的坐标。

例如,N=2,M=4,C=6。

坐标张量,其形状为(2, 4, 3):

[[[1, 2, 3]
  [4, 5, 6]
  [7, 8, 9]
  [8, 7, 6]]

 [11, 12, 13]
 [14, 15, 16]
 [17, 18, 19]
 [18, 17, 16]]]

指数张量,形状为(2, 1, 6):

[[[0, 1, 2, 1, 2, 3]]
 [[1, 2, 3, 2, 1, 0]]]

比如indices中的第一个元素是0,我想从坐标张量中抓取[1,2,3]出来。对于第二个元素 (1),我想把 [4, 5, 6] 抓出来。对于下一个维度(3)中的第三个元素,我想把[18,17,16]抓出来

结果张量将如下所示:

[[[1, 2, 3]  # 0
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [8, 7, 6]] # 3

 [[14, 15, 16] # 1
  [17, 18, 19] # 2
  [18, 17, 16] # 3
  [17, 18, 19] # 2
  [14, 15, 16] # 1
  [11, 12, 13]]]# 0

它的形状是 (2, 6, 3).

我尝试使用 torch.gather 但无法正常工作。我写了一个简单的算法来枚举所有 N 个组,但它确实很慢,即使使用 TorchScript 的 JIT。那么,如何在pytorch中高效地编写这个呢?

您可以使用 integer array indexing combined with broadcasting semantics 得到结果。

import torch

x = torch.tensor([
    [[1, 2, 3], 
     [4, 5, 6], 
     [7, 8, 9], 
     [8, 7, 6]],
    [[11, 12, 13],
     [14, 15, 16],
     [17, 18, 19],
     [18, 17, 16]],
])

i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)

# y is [2, 6, ...]
y = x[rows, cols]