为什么 torch.nn.Upsample return 是垃圾图片?

Why does torch.nn.Upsample return a junk image?

当我执行下面的代码段时,nn.Upsample 似乎完全破坏了我的形象。我是不是应用错了?

import torch
import imageio
import torch.nn as nn
from matplotlib import pyplot as plt

small = imageio.imread('small.png')                               # shape 200, 390, 4
small_reshaped = small.reshape(4, 200, 390)                       # shape 4, 200, 390
batch = torch.as_tensor(small_reshaped).unsqueeze(0)              # shape 1, 4, 200, 390
ups = nn.Upsample((500, 970))
upsampled_batch = ups(batch)                                      # shape 1, 4, 500, 970
upsampled_small = upsampled_batch[0].reshape(500, 970, 4)         # shape 500, 970, 4
plt.imshow(small)
plt.imshow(upsampled_small)
plt.show()

上采样前:

上采样后:

原图(small.png):

已解决。重塑破坏了形象。我应该换位。 有关详细信息,请参阅 https://discuss.pytorch.org/t/for-beginners-do-not-use-view-or-reshape-to-swap-dimensions-of-tensors/75524

可行的解决方案:

...
small_reshaped = small.transpose(2, 0, 1)                          # shape 4, 200, 390
...
upsampled_small = upsampled_batch[0].transpose(0,1).transpose(1,2) # shape 500, 970, 4
...