Pytorch 排列不改变所需的索引

Pytorch permute not changing desired index

我正在尝试使用 permute 函数来交换张量的轴,但由于某种原因,输出与预期不符。代码的输出是 torch.Size([512, 256, 3, 3]),但我希望它是 torch.Size([256, 512, 3, 3])。看起来我不能使用翻转来切换 0、1 索引。有什么我想念的吗?我希望更改张量,使其形状为 (256, 512, 3, 3)。

可重现代码:

import torch

wtf = torch.rand(3, 3, 512, 256)
wtf = wtf.permute(2, 3, 1, 0)
print(wtf.shape)

提供给 torch.permute 的数字是按照您希望新张量具有的顺序排列的轴索引。

已将 x 设置为 torch.rand(3, 3, 512, 256)

  • 如果你想反转轴的顺序:初始顺序是0, 1, 2, 3,你想要3, 2, 1, 0:

    >>> wtf.permute(3, 2, 1, 0).shape
    torch.Size([256, 512, 3, 3])
    

    反转轴顺序本质上是转置操作:

    >>> wtf.T
    torch.Size([256, 512, 3, 3])
    
  • 如果你只是想反转 保持最后两个的顺序:原始顺序是 0, 1, 2, 3 结果顺序是 3, 2, 0, 1:

    >>> x.permute(3, 2, 0, 1).shape
    torch.Size([256, 512, 3, 3])
    

两个选项的区别在于最后两个大小为 3 的轴将被交换。