tensorflow 的 MonitoredTrainingSession 和切片输入生产者中的死锁
Deadlock in tensorflow's MonitoredTrainingSession and slice input producer
下面的代码死锁:
import tensorflow as tf
def train():
"""Stripped down and modified from cifar10.cifar10_train.train"""
global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
images = tf.constant([[1, 2, 3], [1, 2, 3]])
labels = tf.constant([[1, 2, 3], [1, 2, 3]])
images, labels = tf.train.slice_input_producer([images, labels],
shuffle=False)
# input_var = tf.Variable([0, 0, 0])
# images = input_var.assign(images) # TODO placeholder would work ?
# input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
tf_print = tf.Print(input_batch, [input_batch])
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(tf_print)
if __name__ == '__main__':
train()
但是如果我注释掉 input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
并取消注释程序继续打印的注释行:
I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]
- 为什么会死锁?像我一样使用额外的变量来解决这个问题是正确的方法吗?或者我应该以某种方式使用占位符?
- 我错过了什么,它在 3 个步骤后没有终止?
我不确定你的第一个问题,但我相信当你创建 MonitoredTrainingSession 时它会尝试初始化你的图形变量。但是在您的情况下,变量初始值之一依赖于隐藏在 tf.train.slice_input_producer
后面的出列操作。由于队列还没有启动,代码死锁等待队列入队。
在您评论的实现中,init_op
执行 运行,因此队列可以启动并使您的代码正常工作。
第二个问题的解释如下。
StopAtStepHook
依赖于正在更新的 global_step
张量,这在您的脚本中不是这种情况。这段代码
tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step,1))
会起作用:基本上它将 tf.Print
操作和 global_step
增量组合在一起,所以每次 运行 tf_print
时,global_step
都会递增。
import tensorflow as tf
def train():
"""Stripped down and modified from cifar10.cifar10_train.train"""
global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
images = tf.constant([[1, 2, 3], [1, 2, 3]])
labels = tf.constant([[1, 2, 3], [1, 2, 3]])
images, labels = tf.train.slice_input_producer([images, labels], shuffle=False)
input_var = tf.Variable([0, 0, 0])
images = input_var.assign(images) # TODO placeholder would work ?
input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
tf_print = tf.group(tf.Print(input_batch, [input_batch]),
tf.assign_add(global_step, 1))
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(tf_print)
if __name__ == '__main__':
train()
下面的代码死锁:
import tensorflow as tf
def train():
"""Stripped down and modified from cifar10.cifar10_train.train"""
global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
images = tf.constant([[1, 2, 3], [1, 2, 3]])
labels = tf.constant([[1, 2, 3], [1, 2, 3]])
images, labels = tf.train.slice_input_producer([images, labels],
shuffle=False)
# input_var = tf.Variable([0, 0, 0])
# images = input_var.assign(images) # TODO placeholder would work ?
# input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
tf_print = tf.Print(input_batch, [input_batch])
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(tf_print)
if __name__ == '__main__':
train()
但是如果我注释掉 input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
并取消注释程序继续打印的注释行:
I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]
- 为什么会死锁?像我一样使用额外的变量来解决这个问题是正确的方法吗?或者我应该以某种方式使用占位符?
- 我错过了什么,它在 3 个步骤后没有终止?
我不确定你的第一个问题,但我相信当你创建 MonitoredTrainingSession 时它会尝试初始化你的图形变量。但是在您的情况下,变量初始值之一依赖于隐藏在
tf.train.slice_input_producer
后面的出列操作。由于队列还没有启动,代码死锁等待队列入队。 在您评论的实现中,init_op
执行 运行,因此队列可以启动并使您的代码正常工作。第二个问题的解释如下。
StopAtStepHook
依赖于正在更新的global_step
张量,这在您的脚本中不是这种情况。这段代码tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step,1))
会起作用:基本上它将tf.Print
操作和global_step
增量组合在一起,所以每次 运行tf_print
时,global_step
都会递增。import tensorflow as tf def train(): """Stripped down and modified from cifar10.cifar10_train.train""" global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook images = tf.constant([[1, 2, 3], [1, 2, 3]]) labels = tf.constant([[1, 2, 3], [1, 2, 3]]) images, labels = tf.train.slice_input_producer([images, labels], shuffle=False) input_var = tf.Variable([0, 0, 0]) images = input_var.assign(images) # TODO placeholder would work ? input_batch = tf.scatter_nd_update(images, [[1, 2]], [77]) tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step, 1)) with tf.train.MonitoredTrainingSession( hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(tf_print) if __name__ == '__main__': train()