Avoid Loop for Selecting Dimensions in a Batch Tensor in PyTorch

I have a batch tensor and a second tensor that holds, for each batch element, the index of the dimension (channel) to select from the batch tensor. Currently, I am looping over the batch tensor, as shown in the snippet below:

import torch

# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)

# notice that channels_id has 8 elements, i.e., batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

This is how I select the desired channel inside a for loop and then stack the results into a single tensor:

batch_out = torch.stack([batch_i[channel_i] for batch_i, channel_i in zip(batch_data, channels_id)])
batch_out.size()  # prints torch.Size([8, 240, 320])
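
As a sanity check (not part of the original snippet, just a small sketch assuming the tensors above), each output slice should simply be the channel of the corresponding batch element picked by channels_id:

# verify: batch_out[i] is batch_data[i] restricted to channel channels_id[i]
for i in range(batch_size):
    assert torch.equal(batch_out[i], batch_data[i, channels_id[i]])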

It works fine. However, is there a better, more idiomatic PyTorch way to achieve the same result?

Based on the hint from @Shai, I could make it work using the torch.gather function. Below is the complete code:

import torch

# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)

# notice that channels_id has 8 elements, i.e., batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

# reshape channels_id to (8, 1, 240, 320) so it can serve as the index tensor for gather along dim=1
channels_id = channels_id.view(-1, 1, 1, 1).repeat((1, 1) + batch_data.size()[-2:])

# gather picks, for every batch element, the channel given by channels_id along dim=1;
# the result keeps a singleton channel dimension, which squeeze() removes
batch_out = torch.gather(batch_data, 1, channels_id).squeeze()
batch_out.size()  # prints torch.Size([8, 240, 320])
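
For completeness, here is an alternative I find handy (a sketch of my own, not part of @Shai's suggestion; it assumes the same batch_data and the original, un-reshaped channels_id of shape (8,)): plain integer array indexing gives the same result without building a full-size index tensor.

import torch

batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

# advanced (integer array) indexing: pick channels_id[i] from batch_data[i] for every i
batch_out = batch_data[torch.arange(batch_size), channels_id]
batch_out.size()  # prints torch.Size([8, 240, 320])

# matches the loop-based version
batch_out_loop = torch.stack([b[c] for b, c in zip(batch_data, channels_id)])
assert torch.equal(batch_out, batch_out_loop)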