Avoid Loop for Selecting Dimensions in a Batch Tensor in PyTorch
I have a batch tensor and another tensor that holds, for each sample, the index of the dimension (channel) to select from the batch tensor. Currently, I am looping over the batch tensor, as shown in the code snippet below:
import torch
# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)
# notice that channels_id has 8 elements, i.e., batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])
This is how I select the dimension inside a for loop and then stack the results back into a single tensor:
batch_out = torch.stack([batch_i[channel_i] for batch_i, channel_i in zip(batch_data, channels_id)])
batch_out.size() # prints torch.Size([8, 240, 320])
It works fine. However, is there a better, more PyTorch-idiomatic way to achieve the same result?
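One loop-free alternative is PyTorch's advanced (integer array) indexing: pairing each batch position with its channel id selects one channel per sample in a single call. A minimal sketch, reusing the tensors defined above:

# index dim 0 with the batch positions and dim 1 with the per-sample channel ids
batch_out = batch_data[torch.arange(batch_size), channels_id]
batch_out.size() # prints torch.Size([8, 240, 320])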
Based on @Shai's hint, I could make it work using the torch.gather function. Below is the complete code:
import torch
# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)
# notice that channels_id has 8 elements, i.e., batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])
# reshape channels_id to (8, 1, 240, 320): one channel index per spatial location
channels_id = channels_id.view(-1, 1, 1, 1).repeat((1, 1) + batch_data.size()[-2:])
# gather along the channel dimension (dim=1), then drop the singleton dimension;
# squeeze(1) is safer than squeeze(), which would also collapse a batch of size 1
batch_out = torch.gather(batch_data, 1, channels_id).squeeze(1)
batch_out.size() # prints torch.Size([8, 240, 320])
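As an aside, since the index tensor repeats the same channel id across the last two dimensions, expand can build it as a view instead of allocating copies with repeat. A sketch, assuming a reasonably recent PyTorch where torch.gather accepts the resulting non-contiguous index tensor:

# expand returns a view; no memory is copied for the repeated indices
channels_id = channels_id.view(-1, 1, 1, 1).expand((-1, -1) + batch_data.size()[-2:])
batch_out = torch.gather(batch_data, 1, channels_id).squeeze(1)
batch_out.size() # prints torch.Size([8, 240, 320])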