Deadlock in TensorFlow's MonitoredTrainingSession and slice input producer

The following code deadlocks:

import tensorflow as tf

def train():
    """Stripped down and modified from cifar10.cifar10_train.train"""
    global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
    images = tf.constant([[1, 2, 3], [1, 2, 3]])
    labels = tf.constant([[1, 2, 3], [1, 2, 3]])
    images, labels = tf.train.slice_input_producer([images, labels],
                                                   shuffle=False)
    # input_var = tf.Variable([0, 0, 0])
    # images = input_var.assign(images) # TODO placeholder would work ?
    # input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
    input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
    tf_print = tf.Print(input_batch, [input_batch])
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(tf_print)

if __name__ == '__main__':
    train()

However, if I comment out `input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])` and uncomment the commented lines, the program keeps printing:

I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]

  1. I am not sure about every detail of your first question, but I believe that when you create the MonitoredTrainingSession it tries to initialize the variables in your graph. In your case, however, one of the variables' initial values depends on a dequeue operation hidden behind tf.train.slice_input_producer. Since the queue runners have not been started yet, the code deadlocks waiting for the queue to be filled. In the commented-out implementation the init op can run first, so the queue runners get started and your code works. (A placeholder-based variant, answering the TODO in your code, is sketched after the code below.)

  2. The explanation for the second issue is as follows. StopAtStepHook relies on the global_step tensor being updated, which does not happen in your script. The line tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step, 1)) fixes it: it groups the tf.Print op with a global_step increment, so every time tf_print is run, global_step is incremented.

    import tensorflow as tf
    
    def train():
        """Stripped down and modified from cifar10.cifar10_train.train"""
        global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
        images = tf.constant([[1, 2, 3], [1, 2, 3]])
        labels = tf.constant([[1, 2, 3], [1, 2, 3]])
        images, labels = tf.train.slice_input_producer([images, labels], shuffle=False)
        input_var = tf.Variable([0, 0, 0])
        images = input_var.assign(images) # TODO placeholder would work ?
        input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
        tf_print = tf.group(tf.Print(input_batch, [input_batch]),
                            tf.assign_add(global_step, 1))
        with tf.train.MonitoredTrainingSession(
                hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(tf_print)
    
    if __name__ == '__main__':
        train()
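
As for the "TODO placeholder would work ?" comment: a placeholder is another way to keep the variable's initial value independent of the queue. Below is a minimal sketch using the same TF 1.x graph API; the names image_ph, input_var and assigned are mine, the scatter indices are simplified to a single-element update, and the dequeued slice is fetched first and then fed back in through feed_dict.

    import tensorflow as tf

    def train():
        """Placeholder variant: the variable's initial value never touches the queue."""
        global_step = tf.contrib.framework.get_or_create_global_step()  # for StopAtStepHook
        images = tf.constant([[1, 2, 3], [1, 2, 3]])
        labels = tf.constant([[1, 2, 3], [1, 2, 3]])
        image_slice, label_slice = tf.train.slice_input_producer([images, labels],
                                                                 shuffle=False)
        image_ph = tf.placeholder(tf.int32, shape=[3])  # fed with a dequeued slice
        input_var = tf.Variable([0, 0, 0])              # plain initializer, no queue dependency
        assigned = tf.assign(input_var, image_ph)       # copy the fed slice into the variable
        input_batch = tf.scatter_nd_update(assigned, [[2]], [77])  # overwrite element 2 with 77
        tf_print = tf.group(tf.Print(input_batch, [input_batch]),
                            tf.assign_add(global_step, 1))  # keep StopAtStepHook counting
        with tf.train.MonitoredTrainingSession(
                hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
            while not mon_sess.should_stop():
                # Queue runners are started by MonitoredTrainingSession, so this
                # dequeue no longer blocks variable initialization.
                image_value = mon_sess.run(image_slice)
                mon_sess.run(tf_print, feed_dict={image_ph: image_value})

    if __name__ == '__main__':
        train()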