在 PyTorch 的数据加载器中使用带有腌制数据的生成器
Using a generator with pickled data in a Dataloader for PyTorch
我之前做过一些预处理和特征选择,我有一个由列表列表组成的pickle训练输入数据,例如(但腌制)
[[1,5,45,13], [23,256,4,2], [1,12,88,78], [-1]]
[[12,45,77,325], [23,257,5,28], [3,7,48,178], [12,77,89,99]]
[[13,22,78,89], [12,33,97], [-1], [-1]]
[-1]
是一个填充标记,但我认为这不重要。
因为文件有好几千兆字节,所以我想节省内存并使用生成器逐行(逐个列表)读取pickle。我已经发现 this answer 可能会有帮助。这就像:
def yield_from_pickle(pfin):
with open(pfin, 'rb') as fhin:
while True:
try:
yield pickle.load(fhin)
except EOFError:
break
接下来是,我希望在 PyTorch (1.0.1) Dataloader. From what I found in other answers, I must feed it a Dataset 中使用这些数据,您可以对其进行子集化,但必须包含 __len__
和 __getitem__
.它可能是这样的:
class TextDataset(Dataset):
def __init__(self, pfin):
self.pfin = pfin
def __len__(self):
# memory-lenient way but exhaust generator?
return sum(1 for _ in self.yield_from_pickle())
def __getitem__(self, index):
# ???
pass
def yield_from_pickle(self):
with open(self.pfin, 'rb') as fhin:
while True:
try:
yield pickle.load(fhin)
except EOFError:
break
但我完全不确定这是否可能。如何以合理的方式实施 __len__
和 __getitem__
?我不认为我用 __len__
做的是个好主意,因为这会耗尽生成器,而且我完全不知道如何在保留生成器的同时安全地实现 __getitem__
。
有没有更好的方法?总而言之,我想构建一个可以提供给 PyTorch 的数据加载器的数据集(因为它具有多处理能力),但以一种内存高效的方式,我不必将整个文件读入内存。
请参阅 my other answer 了解您的选择。
简而言之,要么将每个样本预处理成单独的文件,要么使用不需要完全加载到内存中读取的数据格式。
我之前做过一些预处理和特征选择,我有一个由列表列表组成的pickle训练输入数据,例如(但腌制)
[[1,5,45,13], [23,256,4,2], [1,12,88,78], [-1]]
[[12,45,77,325], [23,257,5,28], [3,7,48,178], [12,77,89,99]]
[[13,22,78,89], [12,33,97], [-1], [-1]]
[-1]
是一个填充标记,但我认为这不重要。
因为文件有好几千兆字节,所以我想节省内存并使用生成器逐行(逐个列表)读取pickle。我已经发现 this answer 可能会有帮助。这就像:
def yield_from_pickle(pfin):
with open(pfin, 'rb') as fhin:
while True:
try:
yield pickle.load(fhin)
except EOFError:
break
接下来是,我希望在 PyTorch (1.0.1) Dataloader. From what I found in other answers, I must feed it a Dataset 中使用这些数据,您可以对其进行子集化,但必须包含 __len__
和 __getitem__
.它可能是这样的:
class TextDataset(Dataset):
def __init__(self, pfin):
self.pfin = pfin
def __len__(self):
# memory-lenient way but exhaust generator?
return sum(1 for _ in self.yield_from_pickle())
def __getitem__(self, index):
# ???
pass
def yield_from_pickle(self):
with open(self.pfin, 'rb') as fhin:
while True:
try:
yield pickle.load(fhin)
except EOFError:
break
但我完全不确定这是否可能。如何以合理的方式实施 __len__
和 __getitem__
?我不认为我用 __len__
做的是个好主意,因为这会耗尽生成器,而且我完全不知道如何在保留生成器的同时安全地实现 __getitem__
。
有没有更好的方法?总而言之,我想构建一个可以提供给 PyTorch 的数据加载器的数据集(因为它具有多处理能力),但以一种内存高效的方式,我不必将整个文件读入内存。
请参阅 my other answer 了解您的选择。
简而言之,要么将每个样本预处理成单独的文件,要么使用不需要完全加载到内存中读取的数据格式。