tf.keras 如何保存 ModelCheckPoint 对象
tf.keras how to save ModelCheckPoint object
ModelCheckpoint 可用于根据特定的监控指标保存最佳模型。所以它显然有关于存储在其对象中的最佳指标的信息。例如,如果你在 google colab 上训练,你的实例可能会在没有警告的情况下被杀死,并且在长时间的训练后你会丢失这些信息。
我试图 pickle ModelCheckpoint 对象但得到:
TypeError: can't pickle _thread.lock objects
这样当我带回我的笔记本时,我可以重复使用这个相同的对象。有没有好的方法来做到这一点?您可以尝试重现:
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
with open('chkpt_cb.pickle', 'w') as f:
pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
我认为您可能误解了 ModelCheckpoint
对象的预期用途。这是一个callback that periodically gets called during training at a particular phase. The ModelCheckpoint callback in particular gets called after every epoch (if you keep the default period=1
) and saves your model to disk in the filename you specify to the filepath
argument. The model is saved in the same way described here。然后如果你想稍后加载那个模型,你可以做类似
的事情
from keras.models import load_model
model = load_model('my_model.h5')
SO 上的其他答案为从已保存的模型继续训练提供了很好的指导和示例,例如:。重要的是,保存的 H5 文件存储了继续训练所需的有关模型的所有信息。
正如 Keras documentation 中所建议的,您不应使用 pickle 来序列化您的模型。只需使用 'fit' 函数注册 ModelCheckpoint 回调:
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
model.fit(x_train, y_train,
epochs=100,
steps_per_epoch=5000,
callbacks=[chkpt_cb])
您的模型将保存在一个以您的名字命名的 H5 文件中,其中会自动为您格式化纪元编号和损失值。例如,您保存的第 5 个 epoch 损失为 0.0023 的文件看起来像 model.05-.0023.h5
,并且由于您设置了 save_best_only=True
,只有当您的损失优于之前保存的模型时,模型才会被保存,因此您不要用一堆不需要的模型文件污染你的目录。
如果不对回调对象进行 pickle(由于线程问题且不可取),我可以 pickle 这个:
best = chkpt_cb.best
这存储了 callback 看到的最好的监控指标,它是一个浮点数,你可以在下次 pickle 和 reload,然后这样做:
chkpt_cb.best = best # if chkpt_cb is a brand new object you create when colab killed your session.
这是我自己的设置:
# All paths should be on Google Drive, I omitted it here for simplicity.
chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
if os.path.exists('chkpt_cb.best.pickle'):
with open('chkpt_cb.best.pickle', 'rb') as f:
best = pickle.load(f)
chkpt_cb.best = best
def save_chkpt_cb():
with open('chkpt_cb.best.pickle', 'wb') as f:
pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)
save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)
history = model.fit_generator(generator=train_data_gen,
validation_data=dev_data_gen,
epochs=5,
callbacks=[chkpt_cb, save_chkpt_cb_callback])
因此,即使您的 colab 会话被终止,您仍然可以检索最后的最佳指标并将其告知您的新实例,并照常继续训练。当您重新编译有状态优化器并且可能导致 loss/metric 回归并且不想在前几个时期保存这些模型时,这尤其有用。
ModelCheckpoint 可用于根据特定的监控指标保存最佳模型。所以它显然有关于存储在其对象中的最佳指标的信息。例如,如果你在 google colab 上训练,你的实例可能会在没有警告的情况下被杀死,并且在长时间的训练后你会丢失这些信息。
我试图 pickle ModelCheckpoint 对象但得到:
TypeError: can't pickle _thread.lock objects
这样当我带回我的笔记本时,我可以重复使用这个相同的对象。有没有好的方法来做到这一点?您可以尝试重现:
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
with open('chkpt_cb.pickle', 'w') as f:
pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
我认为您可能误解了 ModelCheckpoint
对象的预期用途。这是一个callback that periodically gets called during training at a particular phase. The ModelCheckpoint callback in particular gets called after every epoch (if you keep the default period=1
) and saves your model to disk in the filename you specify to the filepath
argument. The model is saved in the same way described here。然后如果你想稍后加载那个模型,你可以做类似
from keras.models import load_model
model = load_model('my_model.h5')
SO 上的其他答案为从已保存的模型继续训练提供了很好的指导和示例,例如:
正如 Keras documentation 中所建议的,您不应使用 pickle 来序列化您的模型。只需使用 'fit' 函数注册 ModelCheckpoint 回调:
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
model.fit(x_train, y_train,
epochs=100,
steps_per_epoch=5000,
callbacks=[chkpt_cb])
您的模型将保存在一个以您的名字命名的 H5 文件中,其中会自动为您格式化纪元编号和损失值。例如,您保存的第 5 个 epoch 损失为 0.0023 的文件看起来像 model.05-.0023.h5
,并且由于您设置了 save_best_only=True
,只有当您的损失优于之前保存的模型时,模型才会被保存,因此您不要用一堆不需要的模型文件污染你的目录。
如果不对回调对象进行 pickle(由于线程问题且不可取),我可以 pickle 这个:
best = chkpt_cb.best
这存储了 callback 看到的最好的监控指标,它是一个浮点数,你可以在下次 pickle 和 reload,然后这样做:
chkpt_cb.best = best # if chkpt_cb is a brand new object you create when colab killed your session.
这是我自己的设置:
# All paths should be on Google Drive, I omitted it here for simplicity.
chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
monitor='val_loss',
verbose=1,
save_best_only=True)
if os.path.exists('chkpt_cb.best.pickle'):
with open('chkpt_cb.best.pickle', 'rb') as f:
best = pickle.load(f)
chkpt_cb.best = best
def save_chkpt_cb():
with open('chkpt_cb.best.pickle', 'wb') as f:
pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)
save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)
history = model.fit_generator(generator=train_data_gen,
validation_data=dev_data_gen,
epochs=5,
callbacks=[chkpt_cb, save_chkpt_cb_callback])
因此,即使您的 colab 会话被终止,您仍然可以检索最后的最佳指标并将其告知您的新实例,并照常继续训练。当您重新编译有状态优化器并且可能导致 loss/metric 回归并且不想在前几个时期保存这些模型时,这尤其有用。