如何使用 PyTorch 在自定义图像数据集中创建 train-val 拆分?

How to create a train-val split in custom image datasets using PyTorch?

我想从我的原始训练集中创建一个训练+验证集。该目录分为训练和测试。我加载原始训练集并想将其拆分为训练集和验证集,以便我可以使用 train_loader[=19 评估训练期间的验证损失=].

没有很多关于此的文档可以清楚地解释事情。

查看答案 here

我也把它贴在下面了。

============================================= =========

使用ImageFolder读取数据。任务是二值图像分类,数据集中有 498 张图像,它们平均分布在 类(每个 249 张图像)中。

img_dataset = ImageFolder(..., transforms=t)

1。 SubsetRandomSampler

dataset_size = len(img_dataset)
dataset_indices = list(range(dataset_size))

np.random.shuffle(dataset_indices)

val_split_index = int(np.floor(0.2 * dataset_size))

train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]

train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)


train_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=8, sampler=train_sampler)
validation_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=1, sampler=val_sampler)

2。 random_split

这里,在总共 498 张图像中,随机分配 400 张用于训练,其余 98 张用于验证。

dataset_train, dataset_valid = random_split(img_dataset, (400, 98))

train_loader = DataLoader(dataset=dataset_train, shuffle=True, batch_size=8)
val_loader = DataLoader(dataset=dataset_valid, shuffle=False, batch_size=1)

3。 WeightedRandomSampler

if someone stumbled here searching for WeightedRandomSampler, check @ptrblck's answer here for a good explanation of what is happening below.

现在,WeightedRandomSampler 如何适合创建 train+val 集?因为与 SubsetRandomSamplerrandom_split() 不同,我们不会在这里拆分 train 和 val。我们只是确保每个批次在训练期间获得相同数量的 类。

所以,我猜我们需要在 random_split()SubsetRandomSampler 之后使用WeightedRandomSampler 。但这并不能确保 train 和 val 在 类 之间具有相似的比率。

target_list = []

for _, t in imgdataset:
    target_list.append(t)

target_list = torch.tensor(target_list)
target_list = target_list[torch.randperm(len(target_list))]

# get_class_distribution() is a function that takes in a dataset and 
# returns a dictionary with class count. In this case, the 
# get_class_distribution(img_dataset)  returns the following - 
# {'class_0': 249, 'class_0': 249}
class_count = [i for i in get_class_distribution(img_dataset).values()]
class_weights = 1./torch.tensor(class_count, dtype=torch.float) 

class_weights_all = class_weights[target_list]

weighted_sampler = WeightedRandomSampler(
    weights=class_weights_all,
    num_samples=len(class_weights_all),
    replacement=True
)