在 TF2 中使用自定义训练循环时,如何保存所有变量(不仅是网络变量)以便能够恢复训练?
How to save all variables (not only net variables) to be able to resume training when using custom training loops in TF2?
我正在使用自定义训练循环在 TF2 中训练 Model
。我希望能够在给定时刻保存我的优化状态,以便稍后重新启动它。要保存的变量是模型参数,还有优化变量,以及一些其他变量。
在 TF1 中,这实际上甚至不是问题,因为 tf.train.Saver
默认会保存所有变量。
现在,在 TF2 中如何做到这一点?
根据指南,在 TF2 中,保存是通过 Keras 公开的功能完成的,使用特定的回调或 Model
方法。这两种方法 都可以 保存更多的不仅仅是网络参数,但为了能够实现这一点,需要使用 tf.Model.compile
编译模型,以便将所有内容捆绑在一起。但是,对于自定义训练循环,完全没有调用 compile
.
那么当一个人没有使用 compile
/fit
的正确路径时,如何保存我的所有变量以便能够恢复训练?
使用tf.train.Checkpoint
并将所有要保存的变量放入此函数。
tf.train.Checkpoint(model=model, optimizer=optimizer, [xx=xx])
更多详情请看这里tf.train.Checkpoint
我正在使用自定义训练循环在 TF2 中训练 Model
。我希望能够在给定时刻保存我的优化状态,以便稍后重新启动它。要保存的变量是模型参数,还有优化变量,以及一些其他变量。
在 TF1 中,这实际上甚至不是问题,因为 tf.train.Saver
默认会保存所有变量。
现在,在 TF2 中如何做到这一点?
根据指南,在 TF2 中,保存是通过 Keras 公开的功能完成的,使用特定的回调或 Model
方法。这两种方法 都可以 保存更多的不仅仅是网络参数,但为了能够实现这一点,需要使用 tf.Model.compile
编译模型,以便将所有内容捆绑在一起。但是,对于自定义训练循环,完全没有调用 compile
.
那么当一个人没有使用 compile
/fit
的正确路径时,如何保存我的所有变量以便能够恢复训练?
使用tf.train.Checkpoint
并将所有要保存的变量放入此函数。
tf.train.Checkpoint(model=model, optimizer=optimizer, [xx=xx])
更多详情请看这里tf.train.Checkpoint