在 PyTorch 中将两个 torchvision.dataset 对象组合成一个 DataLoader

Combing two torchvision.dataset objects into a single DataLoader in PyTorch

我正在 PyTorch 中的 Cifar-10 数据集上训练 GANS(因此不需要 train/val/test 拆分),我希望能够在片段中组合 torchvision.datasets.CIFAR10下面形成一个单一的 torch.utils.data.DataLoader 迭代器。我目前的解决方案是这样的:

import torchvision
import torch
batch_size = 128
cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)
cifar_dl1 = torch.utils.data.DataLoader(cifar_trainset, batch_size=batch_size, num_workers=12, persistent_workers=True,
                                          shuffle=True, pin_memory=True)
cifar_dl2 = torch.utils.data.DataLoader(cifar_testset, batch_size=batch_size, num_workers=12, persistent_workers=True,
                                          shuffle=True, pin_memory=True)

然后在我的训练循环中我有类似的东西:

for dl in [cifar_dl1, cifar_l2]:
   for data in dl:
      # training

在多线程上下文中,这种方法的问题是,我发现我的设置和这个任务的最佳工作人员数量是 12,现在我声明总共有 24 个工作人员,这显然也是许多,更不用说与重新迭代每个数据加载器相关的启动时间成本,尽管每个数据加载器都有持久性工作标志的好处。

非常感谢任何对此问题的解决方案。

您可以使用 torch.utils.data 模块中的 ConcatDataset

代码段:

import torch    
import torchvision

batch_size = 128

cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)

cifar_dataset = torch.utils.data.ConcatDataset([cifar_trainset, cifar_testset])

cifar_dataloader = torch.utils.data.DataLoader(cifar_dataset, batch_size=batch_size, num_workers=12, persistent_workers=True,
                                          shuffle=True, pin_memory=True)

for data in cifar_dataloader:
    # training