HDF5 数据加载非常慢。导致 GPU 的波动率为零
HDF5 dataloading very slow. Causes zero % volatility in GPUs
我正在使用自定义 PyTorch 数据类从我创建的 H5 数据集中加载实例。但是,加载样本时它似乎非常慢。我遵循了一些关于处理大型 HDF5 数据集的建议,但我想知道我是否在做一些明显错误的事情。如果这有所作为,我将在 Linux 上部署我的代码。我是 运行 4 个 GPU 上的代码,nn.dataparallel 适合我的模型。由于数据加载非常缓慢,因此 GPU 波动率为 0%。这是我的数据类加载器:
import h5py
from torch.utils import data
class Features_Dataset(data.Dataset):
def __init__(self, archive, phase):
self.archive = archive
self.phase = phase
def __getitem__(self, index):
with h5py.File(self.archive, 'r', libver='latest', swmr=True) as archive:
datum = archive[str(self.phase) + '_all_arrays'][index]
label = archive[str(self.phase) + '_labels'][index]
path = archive[str(self.phase) + '_img_paths'][index]
return datum, label, path
def __len__(self):
with h5py.File(self.archive, 'r', libver='latest', swmr=True) as archive:
datum = archive[str(self.phase) + '_all_arrays']
return len(datum)
if __name__ == '__main__':
train_dataset = Features_Dataset(archive= "featuresdata/train.hdf5", phase= 'train')
trainloader = data.DataLoader(train_dataset, num_workers=8, batch_size=128)
print(len(trainloader))
for i, (data, label, path) in enumerate(trainloader):
print(path)
我是不是遗漏了什么明显的东西?有没有更好的快速加载实例的方法?
编辑:
这是更新后的数据类,但是我现在在尝试使用多处理时遇到 picling 错误。
import h5py
from torch.utils import data
import torch.multiprocessing as mp
mp.set_start_method('spawn')
class Features_Dataset(data.Dataset):
def __init__(self, archive, phase):
self.archive = h5py.File(archive, 'r')
self.labels = self.archive[str(phase) + '_labels']
self.data = self.archive[str(phase) + '_all_arrays']
self.img_paths = self.archive[str(phase) + '_img_paths']
def __getitem__(self, index):
datum = self.data[index]
label = self.labels[index]
path = self.img_paths[index]
return datum, label, path
def __len__(self):
return len(self.data)
def close(self):
self.archive.close()
if __name__ == '__main__':
train_dataset = Features_Dataset(archive= "featuresdata/train.hdf5", phase= 'train')
trainloader = data.DataLoader(train_dataset, num_workers=2, batch_size=4)
print(len(trainloader))
for i, (data, label, path) in enumerate(trainloader):
print(path)
你能不能在初始化中不打开文件一次并存储文件处理程序?目前,当您调用 get item 或 len 时,您将始终在每次调用时打开文件。
我正在使用自定义 PyTorch 数据类从我创建的 H5 数据集中加载实例。但是,加载样本时它似乎非常慢。我遵循了一些关于处理大型 HDF5 数据集的建议,但我想知道我是否在做一些明显错误的事情。如果这有所作为,我将在 Linux 上部署我的代码。我是 运行 4 个 GPU 上的代码,nn.dataparallel 适合我的模型。由于数据加载非常缓慢,因此 GPU 波动率为 0%。这是我的数据类加载器:
import h5py
from torch.utils import data
class Features_Dataset(data.Dataset):
def __init__(self, archive, phase):
self.archive = archive
self.phase = phase
def __getitem__(self, index):
with h5py.File(self.archive, 'r', libver='latest', swmr=True) as archive:
datum = archive[str(self.phase) + '_all_arrays'][index]
label = archive[str(self.phase) + '_labels'][index]
path = archive[str(self.phase) + '_img_paths'][index]
return datum, label, path
def __len__(self):
with h5py.File(self.archive, 'r', libver='latest', swmr=True) as archive:
datum = archive[str(self.phase) + '_all_arrays']
return len(datum)
if __name__ == '__main__':
train_dataset = Features_Dataset(archive= "featuresdata/train.hdf5", phase= 'train')
trainloader = data.DataLoader(train_dataset, num_workers=8, batch_size=128)
print(len(trainloader))
for i, (data, label, path) in enumerate(trainloader):
print(path)
我是不是遗漏了什么明显的东西?有没有更好的快速加载实例的方法?
编辑:
这是更新后的数据类,但是我现在在尝试使用多处理时遇到 picling 错误。
import h5py
from torch.utils import data
import torch.multiprocessing as mp
mp.set_start_method('spawn')
class Features_Dataset(data.Dataset):
def __init__(self, archive, phase):
self.archive = h5py.File(archive, 'r')
self.labels = self.archive[str(phase) + '_labels']
self.data = self.archive[str(phase) + '_all_arrays']
self.img_paths = self.archive[str(phase) + '_img_paths']
def __getitem__(self, index):
datum = self.data[index]
label = self.labels[index]
path = self.img_paths[index]
return datum, label, path
def __len__(self):
return len(self.data)
def close(self):
self.archive.close()
if __name__ == '__main__':
train_dataset = Features_Dataset(archive= "featuresdata/train.hdf5", phase= 'train')
trainloader = data.DataLoader(train_dataset, num_workers=2, batch_size=4)
print(len(trainloader))
for i, (data, label, path) in enumerate(trainloader):
print(path)
你能不能在初始化中不打开文件一次并存储文件处理程序?目前,当您调用 get item 或 len 时,您将始终在每次调用时打开文件。