将 4 维张量转换为列表列表的列表 (Python)

Converting 4 dimensional tensors into list of lists of lists (Python)

我有 6 个形状为 (batch_size, S, S, 1) 的张量,我想将它们组合成一个 python 大小列表 (batch_size, S*S , 6) - 所以张量的每个元素都应该在内部列表中。

不用循环也能实现吗?有什么有效的解决方法?

batch_size=10S=4 为了这个例子的目的:

 >>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]

事实上,第一步是在最后一个维度上连接张量 axis=3:

>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])

然后重塑以展平 axis=1axis=2,您可以在此处使用 torch.flatten 这样做,因为两个轴相邻:

>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])