torch中使用random_split后如何获取train_dataset的路径名

How to obtain the path name of train_dataset after using random_split in torch

我有以下代码:

import torch, torchvision
root_dataset ="./data"
dataset = torchvision.datasets.folder.ImageFolder(root=root_dataset, transform=None, target_transform=None)
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(
    dataset=dataset,
    lengths=[num_train, num_valid]
)

我的问题是:

torch中使用random_split后如何获取train_dataset路径的名称列表?

谢谢。

路径(和标签)存储在 dataset.imgs 中。例如,对于 imagenet:

In [ ]: print(dataset.imgs[0])
Out [ ]: ('/shareDB/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG', 0) 

拆分数据集后,每次拆分指向原始数据集:

In [ ]: len(train_dataset.dataset), len(valid_dataset.dataset)
Out [ ]: (50000, 50000)

但是,每个拆分还包含为拆分选择的原始数据集中的样本索引。您可以使用这些索引和原始数据集来获取为每个拆分选择的图像列表:

valid_imgs = [valid_dataset.dataset.imgs[i_] for i_ in valid_dataset.indices]
train_imgs = [train_dataset.dataset.imgs[i_] for i_ in train_dataset.indices]