用其他运算符替换 torch.gather?
Replace torch.gather by other operator?
我有一个脚本代码,其中 x1
和 x2
大小为 1x68x8x8
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)
max_ids = torch.argmax(x1, 1)
max_ids = max_ids.view(-1, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1
至于上面的代码,当我使用旧的 onnx
时,我遇到了 torch.gather
的问题。因此,我想找到一种替代解决方案,用其他运算符替换 toch.gather
但与上述代码提供相同的输出。你能给我一些建议吗?
一种解决方法是使用等效的 numpy 方法。如果您在某处包含 import numpy as np
语句,您可以执行以下操作。
outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))
如果这给你一个与毕业相关的错误,试试
outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))
一种没有 numpy 的方法:在这种情况下,max_ids
似乎每行只包含一个条目。因此,我相信以下方法会起作用:
max_ids = torch.argmax(x1, 1) # do not reshape
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]
我有一个脚本代码,其中 x1
和 x2
大小为 1x68x8x8
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)
max_ids = torch.argmax(x1, 1)
max_ids = max_ids.view(-1, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1
至于上面的代码,当我使用旧的 onnx
时,我遇到了 torch.gather
的问题。因此,我想找到一种替代解决方案,用其他运算符替换 toch.gather
但与上述代码提供相同的输出。你能给我一些建议吗?
一种解决方法是使用等效的 numpy 方法。如果您在某处包含 import numpy as np
语句,您可以执行以下操作。
outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))
如果这给你一个与毕业相关的错误,试试
outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))
一种没有 numpy 的方法:在这种情况下,max_ids
似乎每行只包含一个条目。因此,我相信以下方法会起作用:
max_ids = torch.argmax(x1, 1) # do not reshape
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]