在 pytorch 中连接数据集

concat datasets in pytorch

我的文件夹中有一些数据集,我正在使用连接数据集连接它们。所以,我有这样的数据文件夹(注意文件夹 1 和 2 只有 1 class 而不是 2):

-表示子文件夹

folder0
-cats
-dogs

folder1
-cats

folder2
-cats

folder3
-dogs

然后我这样做:

    trainset1 = datasets.ImageFolder(folder0, loader=my_loader, transform=SomeAug())    
    trainset2 = datasets.ImageFolder(folder1, loader=my_loader, transform=SomeAug())    
    trainset3 = datasets.ImageFolder(folder2, loader=my_loader, transform=SomeAug())    
    trainset = torch.utils.data.ConcatDataset([trainset1, trainset2, trainset3])

这是这样做的合法方式吗?当我通过以下方式查看总图像时:

len(train_loader.dataset))

它加起来正确。

然而,当我这样做时:

print(trainset.classes)

它让我:

AttributeError: 'ConcatDataset' object has no attribute 'classes'

当我只使用一个数据集时它不会。

我只是想确保在使用 concat 数据集方法时没有问题。

ImageFolder 继承自 DatasetFolder,后者有一个 class 方法 find_classes,在构造函数中调用该方法来初始化变量 DatasetFolder.classes。因此,您可以毫无错误地调用 trainset.classes

但是,ConcatDataset 并没有继承自 ImageFolder,更普遍的是默认情况下不会实现 classes 变量。一般来说,这样做会很困难,因为 ImageFolder 查找 classes 的方法依赖于特定的文件结构,而 ConcatDataset 不假定这样的文件结构,因此它可以使用一组更通用的数据集。

如果此功能对您来说必不可少,您可以编写一个简单的数据集类型,该类型继承自 ConcatDataset,特别需要 ImageFolder 个数据集,并将 classes 作为来自每个组成数据集的可能 classes。