TF LSTM:从训练会话中保存状态以便稍后进行预测会话
TF LSTM: Save State from training session for prediction session later
我正在尝试从训练中保存最新的 LSTM 状态,以便稍后在预测阶段重复使用。我遇到的问题是,在 TF LSTM 模型中,状态通过占位符和 numpy 数组的组合从一个训练迭代传递到下一个迭代——默认情况下,这两者似乎都不包含在图表中已保存。
为了解决这个问题,我创建了一个专用的 TF 变量来保存最新版本的状态,以便将其添加到会话图中,如下所示:
# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
这似乎可以很好地将 savedState
变量添加到保存的会话图中,并且以后可以很容易地用会话的其余部分恢复。
但问题是,我设法在恢复后的会话中实际使用该变量的唯一方法是,如果我在恢复会话后初始化会话中的所有变量(这似乎重置了所有经过训练的变量,包括 weights/biases/etc.!)。如果我先初始化变量然后恢复会话(这在保留训练有素的变量方面效果很好),那么我会收到一个错误,提示我正在尝试访问一个未初始化的变量。
我知道有一种方法可以初始化特定的单个变量(我最初保存它时正在使用它)但问题是当我们恢复它们时,我们通过名称将它们称为字符串,我们不会只传递变量本身?!
# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')
完成这项工作的正确方法是什么?作为一种解决方法,我目前将状态作为 numpy 数组保存到 CSV,然后以相同的方式恢复它。它工作正常,但显然不是最干净的解决方案,因为 saving/restoring TF 会话的每个其他方面都工作得很好。
感谢任何建议!
**编辑:
这是运行良好的代码,如以下接受的答案中所述:
# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])
我还没有测试过这种方法是否会无意中初始化保存的会话中的任何其他变量,但不明白为什么会这样,因为我们只 运行 特定的变量。
问题是在构建 Saver
之后创建新的 tf.Variable
意味着 Saver
不知道新变量。它仍然保存在元图中,但没有保存在检查点中:
import tensorflow as tf
with tf.Graph().as_default():
var_a = tf.get_variable("a", shape=[])
saver = tf.train.Saver()
var_b = tf.get_variable("b", shape=[])
print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
initializer = tf.global_variables_initializer()
with tf.Session() as session:
session.run([initializer])
saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
with tf.Session() as session:
new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!
我已经用 Saver
知道的变量对上面问题的快速再现进行了注释。
现在,解决方案相对简单。我建议在 Saver
之前创建 Variable
,然后使用 tf.assign 更新它的值(确保你 运行 op tf.assign
返回)。分配的值将保存在检查点中,并像其他变量一样恢复。
当 None
传递给其 var_list
构造函数参数时,Saver
可以更好地处理这种情况(即它可以自动获取新变量)。为此,请随时 open a feature request on Github。
我正在尝试从训练中保存最新的 LSTM 状态,以便稍后在预测阶段重复使用。我遇到的问题是,在 TF LSTM 模型中,状态通过占位符和 numpy 数组的组合从一个训练迭代传递到下一个迭代——默认情况下,这两者似乎都不包含在图表中已保存。
为了解决这个问题,我创建了一个专用的 TF 变量来保存最新版本的状态,以便将其添加到会话图中,如下所示:
# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
这似乎可以很好地将 savedState
变量添加到保存的会话图中,并且以后可以很容易地用会话的其余部分恢复。
但问题是,我设法在恢复后的会话中实际使用该变量的唯一方法是,如果我在恢复会话后初始化会话中的所有变量(这似乎重置了所有经过训练的变量,包括 weights/biases/etc.!)。如果我先初始化变量然后恢复会话(这在保留训练有素的变量方面效果很好),那么我会收到一个错误,提示我正在尝试访问一个未初始化的变量。
我知道有一种方法可以初始化特定的单个变量(我最初保存它时正在使用它)但问题是当我们恢复它们时,我们通过名称将它们称为字符串,我们不会只传递变量本身?!
# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')
完成这项工作的正确方法是什么?作为一种解决方法,我目前将状态作为 numpy 数组保存到 CSV,然后以相同的方式恢复它。它工作正常,但显然不是最干净的解决方案,因为 saving/restoring TF 会话的每个其他方面都工作得很好。
感谢任何建议!
**编辑: 这是运行良好的代码,如以下接受的答案中所述:
# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])
我还没有测试过这种方法是否会无意中初始化保存的会话中的任何其他变量,但不明白为什么会这样,因为我们只 运行 特定的变量。
问题是在构建 Saver
之后创建新的 tf.Variable
意味着 Saver
不知道新变量。它仍然保存在元图中,但没有保存在检查点中:
import tensorflow as tf
with tf.Graph().as_default():
var_a = tf.get_variable("a", shape=[])
saver = tf.train.Saver()
var_b = tf.get_variable("b", shape=[])
print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
initializer = tf.global_variables_initializer()
with tf.Session() as session:
session.run([initializer])
saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
with tf.Session() as session:
new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!
我已经用 Saver
知道的变量对上面问题的快速再现进行了注释。
现在,解决方案相对简单。我建议在 Saver
之前创建 Variable
,然后使用 tf.assign 更新它的值(确保你 运行 op tf.assign
返回)。分配的值将保存在检查点中,并像其他变量一样恢复。
当 None
传递给其 var_list
构造函数参数时,Saver
可以更好地处理这种情况(即它可以自动获取新变量)。为此,请随时 open a feature request on Github。