Skip bad data points when loading data using DataLoader
I'm trying to perform an image classification task using the mini-imagenet dataset. The data I want to use contains some bad data points (I'm not sure why). I want to load this data and train my model on it, skipping the bad data points entirely in the process. How can I do this?
The data loader I'm using is as follows:
import os
import os.path as osp

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MiniImageNet(Dataset):
    def __init__(self, root, train=True, transform=None,
                 index_path=None, index=None, base_sess=None):
        if train:
            setname = 'train'
        else:
            setname = 'test'
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train  # training set or test set
        self.IMAGE_PATH = os.path.join(root, 'miniimagenet/images')
        self.SPLIT_PATH = os.path.join(root, 'miniimagenet/split')
        csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')
        lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]

        self.data = []
        self.targets = []
        self.data2label = {}
        lb = -1
        self.wnids = []

        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(self.IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            self.data.append(path)
            self.targets.append(lb)
            self.data2label[path] = lb

        self.y = self.targets

        if train:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
        else:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, targets
I tried using a try-except block, but in that case the data loader doesn't skip the sample; it returns None instead, which causes an error. How can I skip a data point completely in the data loader?
Try removing the bad data at the end of the __init__ function:
# iterate backwards so deleting an entry doesn't shift the indices still to visit
for i in range(len(self.data) - 1, -1, -1):
    if is_bad_data(self.data[i], self.targets[i]):
        del self.data[i]
        del self.targets[i]
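Here is_bad_data is a placeholder you would write yourself; it is not defined above. A minimal sketch, assuming a sample counts as "bad" when its image file is missing or cannot be decoded by PIL:

from PIL import Image

def is_bad_data(path, target):
    # Hypothetical check: flag a sample as bad if the file is missing,
    # unreadable, or corrupt. Adjust this to whatever "bad" means in
    # your dataset (wrong label, zero-byte file, etc.).
    try:
        with Image.open(path) as img:
            img.verify()  # raises an exception for truncated/corrupt files
        return False
    except (OSError, SyntaxError):
        return True

Note that this opens every image once at construction time, which can be slow for a large dataset; you could instead apply the same check inside the loop that builds self.data so only one pass is needed.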
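Alternatively, if you want to keep the try/except in __getitem__ and return None for bad samples (as described in the question), a custom collate_fn that drops the None entries before batching is a common workaround. A sketch, where dataset stands for an instance of the MiniImageNet class above and the batch size is arbitrary:

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def skip_none_collate(batch):
    # Drop samples for which __getitem__ returned None. A batch made up
    # entirely of bad samples would still fail, since default_collate
    # cannot handle an empty list.
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)

loader = DataLoader(dataset, batch_size=64, shuffle=True,
                    collate_fn=skip_none_collate)

The drawback is that batches containing bad samples come out slightly smaller than batch_size, so filtering in __init__ is usually the cleaner option.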