batch_size 在 tf model.fit() 与 batch_size 在 tf.data.Dataset

batch_size in tf model.fit() vs. batch_size in tf.data.Dataset

我有一个适合主机内存的大型数据集。但是,当我使用 tf.keras 训练模型时,它会产生 GPU 内存不足的问题。然后我查看 tf.data.Dataset 并想使用它的 batch() 方法对训练数据集进行批处理,以便它可以在 GPU 中执行 model.fit()。根据其文档,示例如下:

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

dataset.from_tensor_slices().batch()中的BATCH_SIZE和tf.kerasmodelt.fit()中的batch_size一样吗?

如何选择BATCH_SIZE才能让GPU有足够的数据高效运行又不至于内存溢出?

在这种情况下,您不需要在 model.fit() 中传递 batch_size 参数。它将自动使用您在 tf.data.Dataset().batch().

中使用的 BATCH_SIZE

至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果你看到 OOM 错误,你应该减少它直到你没有 OOM(通常以这种方式 32 --> 16 --> 8 ...)。

在你的情况下,我会从 2 的 batch_size 开始,然后将其增加 2 的幂,然后检查我是否仍然出现 OOM。

如果使用tf.data.Dataset().batch()方法,则不需要提供batch_size参数。

事实上,就连官方documentation也是这么说的:

batch_size : Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).