Pytorch Custom dataloader: TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

I want to use a custom dataset to feed numpy data into a DataLoader. When I set a transform, I get the error TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

import os
import torch
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import transforms

class CustomTensorDataset(Dataset):
    """
    TensorDataset with support for transforms
    """
    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform:
            x = self.transform(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)

te_data    =  torch.FloatTensor(np.ones([100, 3, 32, 32]))
te_targets =  torch.FloatTensor(np.ones([100]))

transform = transforms.Compose([
    transforms.ToTensor(),  # raises the TypeError: te_data[index] is already a torch.Tensor
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=transform)
# testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=None) # --> no error

for item in testset_custom:
    print(item)

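transforms.ToTensor() only accepts a PIL Image or a numpy ndarray, but your te_data and te_targets are torch tensors. To fix this, don't convert them to torch tensors before passing them to the Dataset, and store the images as N x H x W x C. ToTensor expects each image in H x W x C layout and rearranges it to C x H x W itself: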

te_data    =  np.ones([100, 32, 32, 3], dtype=np.float32)  # N x H x W x C; float32 matches the original FloatTensor
te_targets =  np.ones([100], dtype=np.float32)

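Because the inputs are now numpy arrays, which have no size(0) method (ndarray.size is an integer attribute holding the total element count), the condition in the assert also has to change:

assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors)

and __len__ needs the same change, otherwise len(dataset) (which DataLoader calls) fails:

return self.tensors[0].shape[0]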