用其他运算符替换 torch.gather?

Replace torch.gather by other operator?

我有一个脚本代码,其中 x1x2 大小为 1x68x8x8

 tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
 x1 = x1.view(tmp_batch*tmp_channel, -1)        
 max_ids = torch.argmax(x1, 1)            
 max_ids = max_ids.view(-1, 1)
            
 x2 = x2.view(tmp_batch*tmp_channel, -1)
 outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1

至于上面的代码,当我使用旧的 onnx 时,我遇到了 torch.gather 的问题。因此,我想找到一种替代解决方案,用其他运算符替换 toch.gather 但与上述代码提供相同的输出。你能给我一些建议吗?

一种解决方法是使用等效的 numpy 方法。如果您在某处包含 import numpy as np 语句,您可以执行以下操作。

outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))

如果这给你一个与毕业相关的错误,试试

outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))

一种没有 numpy 的方法:在这种情况下,max_ids 似乎每行只包含一个条目。因此,我相信以下方法会起作用:

max_ids = torch.argmax(x1, 1) # do not reshape
            
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]