获取 pytorch 数据集的子集
Taking subsets of a pytorch dataset
我有一个网络,我想在某些数据集上进行训练(例如,假设 CIFAR10
)。我可以通过
创建数据加载器对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
我的问题如下:假设我想进行几次不同的训练迭代。假设我想首先在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络等等。为此,我需要能够访问这些图像。不幸的是,trainset
似乎不允许这样的访问。也就是说,尝试执行 trainset[:1000]
或更普遍的 trainset[mask]
将引发错误。
我可以做
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
然后
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改 trainset.train_data
,所以我需要重新定义 trainset
)。有什么办法可以避免吗?
理想情况下,我想要一些东西"equivalent"到
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
您可以为数据集加载器定义一个自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器)。
class YourSampler(Sampler):
def __init__(self, mask):
self.mask = mask
def __iter__(self):
return (self.indices[i] for i in torch.nonzero(self.mask))
def __len__(self):
return len(self.mask)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler2, shuffle=False, num_workers=2)
PS:您可以在此处找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler
torch.utils.data.Subset
更简单,支持 shuffle
,并且不需要编写自己的采样器:
import torchvision
import torch
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=None)
evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
shuffle=True, num_workers=2)
我有一个网络,我想在某些数据集上进行训练(例如,假设 CIFAR10
)。我可以通过
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
我的问题如下:假设我想进行几次不同的训练迭代。假设我想首先在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络等等。为此,我需要能够访问这些图像。不幸的是,trainset
似乎不允许这样的访问。也就是说,尝试执行 trainset[:1000]
或更普遍的 trainset[mask]
将引发错误。
我可以做
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
然后
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改 trainset.train_data
,所以我需要重新定义 trainset
)。有什么办法可以避免吗?
理想情况下,我想要一些东西"equivalent"到
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
您可以为数据集加载器定义一个自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器)。
class YourSampler(Sampler):
def __init__(self, mask):
self.mask = mask
def __iter__(self):
return (self.indices[i] for i in torch.nonzero(self.mask))
def __len__(self):
return len(self.mask)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
sampler = sampler2, shuffle=False, num_workers=2)
PS:您可以在此处找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler
torch.utils.data.Subset
更简单,支持 shuffle
,并且不需要编写自己的采样器:
import torchvision
import torch
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=None)
evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
shuffle=True, num_workers=2)