每 10 个时期保存模型 tensorflow.keras v2
Save model every 10 epochs tensorflow.keras v2
我正在使用在 tensorflow v2 中定义为子模块的 keras。我正在使用 fit_generator()
方法训练我的模型。我想每 10 个时期保存一次我的模型。我怎样才能做到这一点?
在 Keras 中(不是作为 tf 的子模块),我可以给出 ModelCheckpoint(model_savepath,period=10)
。但在 tf v2 中,他们将其更改为 ModelCheckpoint(model_savepath, save_freq)
,其中 save_freq
可以是 'epoch'
,在这种情况下,每个时期都会保存模型。如果 save_freq
是整数,则在处理完这么多样本后保存模型。但我希望它在 10 个纪元之后。我怎样才能做到这一点?
使用 tf.keras.callbacks.ModelCheckpoint
使用 save_freq='epoch'
并传递一个额外的参数 period=10
.
虽然 official docs 中没有记录,但这就是执行此操作的方法(注意它记录了您可以通过 period
,只是没有解释它的作用)。
明确计算每个时期的批次数量对我有用。
BATCH_SIZE = 20
STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
SAVE_PERIOD = 10
# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))
# Train the model with the new callback
model.fit(train_images,
train_labels,
batch_size=BATCH_SIZE,
steps_per_epoch=STEPS_PER_EPOCH,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0)
我正在使用在 tensorflow v2 中定义为子模块的 keras。我正在使用 fit_generator()
方法训练我的模型。我想每 10 个时期保存一次我的模型。我怎样才能做到这一点?
在 Keras 中(不是作为 tf 的子模块),我可以给出 ModelCheckpoint(model_savepath,period=10)
。但在 tf v2 中,他们将其更改为 ModelCheckpoint(model_savepath, save_freq)
,其中 save_freq
可以是 'epoch'
,在这种情况下,每个时期都会保存模型。如果 save_freq
是整数,则在处理完这么多样本后保存模型。但我希望它在 10 个纪元之后。我怎样才能做到这一点?
使用 tf.keras.callbacks.ModelCheckpoint
使用 save_freq='epoch'
并传递一个额外的参数 period=10
.
虽然 official docs 中没有记录,但这就是执行此操作的方法(注意它记录了您可以通过 period
,只是没有解释它的作用)。
明确计算每个时期的批次数量对我有用。
BATCH_SIZE = 20
STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
SAVE_PERIOD = 10
# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))
# Train the model with the new callback
model.fit(train_images,
train_labels,
batch_size=BATCH_SIZE,
steps_per_epoch=STEPS_PER_EPOCH,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0)