batch_size 在 tf model.fit() 与 batch_size 在 tf.data.Dataset
batch_size in tf model.fit() vs. batch_size in tf.data.Dataset
我有一个适合主机内存的大型数据集。但是,当我使用 tf.keras 训练模型时,它会产生 GPU 内存不足的问题。然后我查看 tf.data.Dataset 并想使用它的 batch() 方法对训练数据集进行批处理,以便它可以在 GPU 中执行 model.fit()。根据其文档,示例如下:
train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
dataset.from_tensor_slices().batch()中的BATCH_SIZE和tf.kerasmodelt.fit()中的batch_size一样吗?
如何选择BATCH_SIZE才能让GPU有足够的数据高效运行又不至于内存溢出?
在这种情况下,您不需要在 model.fit()
中传递 batch_size
参数。它将自动使用您在 tf.data.Dataset().batch()
.
中使用的 BATCH_SIZE
至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果你看到 OOM 错误,你应该减少它直到你没有 OOM(通常以这种方式 32 --> 16 --> 8 ...)。
在你的情况下,我会从 2 的 batch_size 开始,然后将其增加 2 的幂,然后检查我是否仍然出现 OOM。
如果使用tf.data.Dataset().batch()
方法,则不需要提供batch_size
参数。
事实上,就连官方documentation也是这么说的:
batch_size : Integer or None. Number of samples per gradient update.
If unspecified, batch_size will default to 32. Do not specify the
batch_size if your data is in the form of datasets, generators, or
keras.utils.Sequence instances (since they generate batches).
我有一个适合主机内存的大型数据集。但是,当我使用 tf.keras 训练模型时,它会产生 GPU 内存不足的问题。然后我查看 tf.data.Dataset 并想使用它的 batch() 方法对训练数据集进行批处理,以便它可以在 GPU 中执行 model.fit()。根据其文档,示例如下:
train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
dataset.from_tensor_slices().batch()中的BATCH_SIZE和tf.kerasmodelt.fit()中的batch_size一样吗?
如何选择BATCH_SIZE才能让GPU有足够的数据高效运行又不至于内存溢出?
在这种情况下,您不需要在 model.fit()
中传递 batch_size
参数。它将自动使用您在 tf.data.Dataset().batch()
.
至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果你看到 OOM 错误,你应该减少它直到你没有 OOM(通常以这种方式 32 --> 16 --> 8 ...)。
在你的情况下,我会从 2 的 batch_size 开始,然后将其增加 2 的幂,然后检查我是否仍然出现 OOM。
如果使用tf.data.Dataset().batch()
方法,则不需要提供batch_size
参数。
事实上,就连官方documentation也是这么说的:
batch_size : Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).