如何连接特定轴上的张量列表?

How to concatenate a list of tensors on a specific axis?

我有一个形状相同的张量列表 (my_list)。我想将它们连接到通道轴上。 帮助代码

for i in my_list:
    print(i.shape) #[1, 3, 128, 128] => [batch, channel, width, height]

我想得到一个新的张量,即 new_tensor = [1, 3*len(my_list), width, height]

我不想使用 torch.stack() 添加新维度。我无法弄清楚如何使用 torch.cat() 来做到这一点?

举一个例子 list 包含 10 张量 (1, 3, 128, 128):

>>> my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

您希望在 axis=1 上连接您的张量,因为第二维是将张量连接在一起的地方。您可以使用 torch.cat:

>>> res = torch.cat(my_list, axis=1)
>>> res.shape
torch.Size([1, 30, 128, 128])

这实际上相当于在 my_list 垂直堆叠你的张量, 使用 torch.vstack:

>>> res = torch.vstack(my_list)