如何在 tf.data.Dataset 对象上使用 sequence/generator 将部分数据放入内存?
How to use sequence/generator on tf.data.Dataset object to fit partial data into memory?
我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数 (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory) 加载图像,其中 return 是一个 tf.data.Dataset 对象:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=1234,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical")
我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我正在使用 Google Colab 并且可以看到 RAM 使用量在训练期间增加第一个纪元)。
然后我尝试使用Keras Sequence,这是将部分数据加载到RAM的建议解决方案(https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):
class DatasetGenerator(tf.keras.utils.Sequence):
def __init__(self, dataset):
self.dataset = dataset
def __len__(self):
return tf.data.experimental.cardinality(self.dataset).numpy()
def __getitem__(self, idx):
return list(self.dataset.as_numpy_iterator())[idx]
我用以下方法训练模型:
history = model.fit(DatasetGenerator(train_ds), ...)
问题是 getitem() 必须 return 一批有索引的数据。但是,我使用的 list() 函数必须将整个数据集放入 RAM 中,因此在 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。
我的问题:
- 有没有办法实现 getitem() (从数据集对象中获取特定批次)而不将整个对象放入内存?
- 如果第 1 项不可能,是否有任何解决方法?
提前致谢!
我了解到您担心内存中有完整的数据集。
别担心,tf.data.Dataset
API 非常高效,它不会将完整的数据集加载到内存中。
在内部,它只是创建一系列函数,当使用 model.fit()
调用时,它只会加载内存中的批次,而不是完整的数据集。
您可以在此 link 中阅读更多内容,我正在粘贴文档中的重要部分。
The tf.data.Dataset API supports writing descriptive and efficient
input pipelines. Dataset usage follows a common pattern:
Create a source dataset from your input data. Apply dataset
transformations to preprocess the data. Iterate over the dataset and
process the elements. Iteration happens in a streaming fashion, so the
full dataset does not need to fit into memory.
从最后一行可以看出,tf.data.Dataset
API 不会将完整的数据集加载到内存中,而是一次加载一批。
您必须执行以下操作来创建数据集的批次。
train_ds.batch(32)
这将创建大小为 32
的批次。您也可以使用预取来准备一批,然后再进行培训。这消除了模型在训练一批并等待另一批后空闲的瓶颈。
train_ds.batch(32).prefetch(1)
您还可以使用 cache
API 让您的数据管道更快。它将缓存您的数据集并使训练更快。
train_ds.batch(32).prefetch(1).cache()
所以简而言之,如果您担心将整个数据集加载到内存中,则不需要 generator
,tf.data.Dataset
API 会处理它。
希望我的回答对您有所帮助。
我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数 (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory) 加载图像,其中 return 是一个 tf.data.Dataset 对象:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=1234,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical")
我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我正在使用 Google Colab 并且可以看到 RAM 使用量在训练期间增加第一个纪元)。 然后我尝试使用Keras Sequence,这是将部分数据加载到RAM的建议解决方案(https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):
class DatasetGenerator(tf.keras.utils.Sequence):
def __init__(self, dataset):
self.dataset = dataset
def __len__(self):
return tf.data.experimental.cardinality(self.dataset).numpy()
def __getitem__(self, idx):
return list(self.dataset.as_numpy_iterator())[idx]
我用以下方法训练模型:
history = model.fit(DatasetGenerator(train_ds), ...)
问题是 getitem() 必须 return 一批有索引的数据。但是,我使用的 list() 函数必须将整个数据集放入 RAM 中,因此在 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。
我的问题:
- 有没有办法实现 getitem() (从数据集对象中获取特定批次)而不将整个对象放入内存?
- 如果第 1 项不可能,是否有任何解决方法?
提前致谢!
我了解到您担心内存中有完整的数据集。
别担心,tf.data.Dataset
API 非常高效,它不会将完整的数据集加载到内存中。
在内部,它只是创建一系列函数,当使用 model.fit()
调用时,它只会加载内存中的批次,而不是完整的数据集。
您可以在此 link 中阅读更多内容,我正在粘贴文档中的重要部分。
The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:
Create a source dataset from your input data. Apply dataset transformations to preprocess the data. Iterate over the dataset and process the elements. Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
从最后一行可以看出,tf.data.Dataset
API 不会将完整的数据集加载到内存中,而是一次加载一批。
您必须执行以下操作来创建数据集的批次。
train_ds.batch(32)
这将创建大小为 32
的批次。您也可以使用预取来准备一批,然后再进行培训。这消除了模型在训练一批并等待另一批后空闲的瓶颈。
train_ds.batch(32).prefetch(1)
您还可以使用 cache
API 让您的数据管道更快。它将缓存您的数据集并使训练更快。
train_ds.batch(32).prefetch(1).cache()
所以简而言之,如果您担心将整个数据集加载到内存中,则不需要 generator
,tf.data.Dataset
API 会处理它。
希望我的回答对您有所帮助。