如何在pytorch中连接张量?

How to concat a tensor in pytorch?

我想做的是这样的:

import torch 
a = torch.arange(120).reshape(2, 3, 4, 5)
b = torch.cat(list(a), dim=2)

我想知道:

  1. 我必须将张量转换为列表,这会不会导致性能不好?
  2. 连性能都可以,只用tensor可以吗?

您想:

  1. 减少副本数量:在这种特定情况下,由于我们正在重新排列底层数据的布局,因此需要进行副本。

  2. 减少或删除任何 torch.Tensor -> 非 torch.Tensor 转换:这将是使用 GPU 时的痛点,因为您要将数据传入和传出设备的。

您可以通过排列轴执行相同的操作,使 axis=0 转到 axis=-2(最后一个轴之前),然后展平最后两个轴:

>>> a.permute(1,2,0,3).flatten(-2)
tensor([[[  0,   1,   2,   3,   4,  60,  61,  62,  63,  64],
         [  5,   6,   7,   8,   9,  65,  66,  67,  68,  69],
         [ 10,  11,  12,  13,  14,  70,  71,  72,  73,  74],
         [ 15,  16,  17,  18,  19,  75,  76,  77,  78,  79]],

        [[ 20,  21,  22,  23,  24,  80,  81,  82,  83,  84],
         [ 25,  26,  27,  28,  29,  85,  86,  87,  88,  89],
         [ 30,  31,  32,  33,  34,  90,  91,  92,  93,  94],
         [ 35,  36,  37,  38,  39,  95,  96,  97,  98,  99]],

        [[ 40,  41,  42,  43,  44, 100, 101, 102, 103, 104],
         [ 45,  46,  47,  48,  49, 105, 106, 107, 108, 109],
         [ 50,  51,  52,  53,  54, 110, 111, 112, 113, 114],
         [ 55,  56,  57,  58,  59, 115, 116, 117, 118, 119]]])