具有 4D 图像的数据集:预期为 Byte 但发现为 Float

Dataset with 4D images: expected Byte but found Float

我有一些 MRI 扫描,我想从中创建自定义 PyTorch Dataset。每次扫描都是一组 31 张 RGB 图像,因此扫描是 4 维的 (Channels, Depth, Height, Width)。图片是.png,每次扫描是一个包含31张图片的文件夹。加载扫描后,我尝试通过 Conv3D 传递它们,但出现错误(末尾的完整回溯):

x = torch.unsqueeze(dataset[0][0], 0)
x.shape  # torch.Size([1, 3, 31, 512, 512])
m = nn.Conv3d(3,12,3)
out = m(x)

RuntimeError: expected scalar type Byte but found Float

如何解决这个错误?我认为发生这种情况是因为我将扫描作为 NumPy 数组的 NumPy 数组加载,但我不知道该怎么做。如何将 4D 图像数据加载到自定义 Dataset?

这是我的习惯 Dataset class:

import torch
import os
import pandas as pd
from skimage import io
from torch.utils.data import Dataset

class TrainImages(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # The folder containing the images of a scan
        img_path = os.path.join(self.root_dir, str(self.annotations.iloc[index, 0]).zfill(5))
        # Create a tensor out of a numpy array of numpy arrays, where each array is an image in the scan
        image = torch.from_numpy(np.array([np.array(Image.open(os.path.join(str(img_path),"rgb-"+str(i)+".png"))) for i in range(31)]).transpose(3,0,1,2).astype(np.uint8))
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))
        return (image, y_label)

完整追溯:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-29-f3c4dfbd5496> in <module>
      1 m=nn.Conv3d(3,12,3)
----> 2 out=m(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    571                             self.dilation, self.groups)
    572         return F.conv3d(input, self.weight, self.bias, self.stride,
--> 573                         self.padding, self.dilation, self.groups)
    574 
    575 

RuntimeError: expected scalar type Byte but found Float

错误消息可能令人困惑,但问题是您的数据具有 Byte 类型,而 conv3d 需要 Float。您需要将 Dataset__getitem__(...) 中的 np.uint8 更改为 np.float32:

image = torch.from_numpy(np.array([
    np.array(Image.open(os.path.join(str(img_path),"rgb-"+str(i)+".png")))
    for i in range(31)
]).transpose(3, 0, 1, 2).astype(np.float32))  # <<< changed from np.uint8 to float32

或者,在传递给模型之前将 x 转换为 Float

out = m(x.float())

请注意,如果您以后使用.ToTensor()这样的转换,此问题也会得到解决。