TensorFlow shuffle_batch() blocks at end of epoch

I am using tf.train.shuffle_batch() to create batches of input images. It takes a min_after_dequeue parameter that makes sure the internal queue holds at least that many elements, and blocks everything else if it does not.

images, label_batch = tf.train.shuffle_batch(
  [image, label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)
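The blocking behaviour of min_after_dequeue can be sketched in plain Python. This is a toy model, not TensorFlow's implementation: a dequeue of n elements waits while it would leave fewer than min_after_dequeue elements behind, unless the queue has been closed.

```python
import threading
import time

class ToyShuffleQueue:
    """Toy model (not TensorFlow) of RandomShuffleQueue's
    min_after_dequeue semantics."""

    def __init__(self, min_after_dequeue):
        self.items = []
        self.min_after = min_after_dequeue
        self.closed = False
        self.cond = threading.Condition()

    def enqueue(self, x):
        with self.cond:
            self.items.append(x)
            self.cond.notify_all()

    def dequeue_many(self, n):
        with self.cond:
            # This wait is the end-of-epoch hang from the question:
            # block while dequeuing would leave fewer than
            # min_after_dequeue elements, unless the queue is closed.
            while len(self.items) - n < self.min_after and not self.closed:
                self.cond.wait()
            batch, self.items = self.items[:n], self.items[n:]
            return batch

    def close(self):
        # Unblocks any pending dequeues, like RandomShuffleQueue.close().
        with self.cond:
            self.closed = True
            self.cond.notify_all()
```

With 10 elements, min_after_dequeue=4 and a batch of 4, the first dequeue succeeds (6 remain), but the second would leave only 2 and therefore blocks until close() is called.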

At the end of an epoch, when I am running evaluation (I am fairly sure the same happens during training, but I have not tested it), everything blocks. I worked out that at that moment the internal shuffle-batch queue is left with fewer than min_after_dequeue elements. At that point I would ideally just like to dequeue the remaining elements, but I am not sure how.

Apparently this kind of blocking on TF queues can be turned off with the .close() method once you know there are no more elements to enqueue. But since the underlying queue is hidden inside the function, how do I call that method?

You are right: running the RandomShuffleQueue.close() operation will stop the dequeuing threads from blocking when there are fewer than min_after_dequeue elements in the queue.

The tf.train.shuffle_batch() function creates a tf.train.QueueRunner that performs operations on the queue in a background thread. If you start it as follows, passing a tf.train.Coordinator, you will be able to close the queue cleanly (based on the example here):

sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord=coord)

while not coord.should_stop():
  sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(threads)
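The shutdown sequence above can be mimicked with the standard library alone. In this sketch (not TensorFlow) a threading.Event stands in for tf.train.Coordinator, a background thread stands in for the QueueRunner, and a None sentinel models closing the queue so the consumer unblocks instead of hanging:

```python
import threading
import queue

stop = threading.Event()          # coord.request_stop() analogue
q = queue.Queue(maxsize=8)

def queue_runner():
    # Keep enqueuing until asked to stop (the QueueRunner's job),
    # then "close" the queue by enqueuing a sentinel.
    i = 0
    while not stop.is_set():
        try:
            q.put(i, timeout=0.05)
            i += 1
        except queue.Full:
            pass
    q.put(None)                   # queue "closed"

t = threading.Thread(target=queue_runner)
t.start()

consumed = []
while True:
    item = q.get()                # sess.run(train_op) analogue
    if item is None:              # queue closed: stop instead of blocking
        break
    consumed.append(item)
    if len(consumed) == 20:       # decide the epoch is done
        stop.set()                # coord.request_stop()

t.join()                          # coord.join(threads)
```

The key design point is the same as in the TensorFlow snippet: the consumer never blocks forever, because the producer always signals "no more data" before exiting.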

Here is the code I eventually got working, although it produces a bunch of warnings that elements I enqueued were cancelled.

lv = tf.constant(label_list)

label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
# if eval_data:
    # num_epochs = 1
# else:
    # num_epochs = None
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])


class Record(object):
  pass
result = Record()  # simple holder object for the outputs below

reader = tf.WholeFileReader()
result.key, value = reader.read(file_fifo)
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128,128,3])
result.uint8image = image
result.label = label_fifo.dequeue()

images, label_batch = tf.train.shuffle_batch(
  [result.uint8image, result.label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)

#in eval file:
label_enqueue, images, labels = load_input.inputs()
#restore from checkpoint in between
coord = tf.train.Coordinator()
try:
  threads = []
  for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                     start=True))

  num_iter = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))
  true_count = 0  # Counts the number of correct predictions.
  total_sample_count = num_iter * FLAGS.batch_size

  sess.run(label_enqueue)
  step = 0
  while step < num_iter and not coord.should_stop():
    end_epoch = False
    if step > 0:
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        # Check whether the next dequeue would leave too few elements.
        # Note: qr._queue is a private attribute, so this may break
        # across TensorFlow versions.
        size = qr._queue.size().eval()
        if size - FLAGS.batch_size < FLAGS.min_queue_size:
          end_epoch = True
    if end_epoch:
      # Enqueue more labels so that the final batches can complete.
      sess.run(label_enqueue)
    # Actually run the step.
    predictions = sess.run([top_k_op])
    true_count += np.sum(predictions)
    step += 1
except Exception as e:
  coord.request_stop(e)

coord.request_stop()
coord.join(threads)
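The end_epoch/refill logic can be checked with a plain-Python simulation. This is a toy model, slightly simplified from the loop above (the refill may also fire at step 0, standing in for the initial sess.run(label_enqueue)); the parameters are stand-ins for the FLAGS values:

```python
def run_eval(num_examples, batch_size, min_queue_size, num_iter):
    """Simulate the eval loop: each step dequeues batch_size elements,
    and whenever the next dequeue would drop the queue below
    min_queue_size, one full epoch of labels is re-enqueued first
    (the sess.run(label_enqueue) analogue)."""
    size = 0        # elements currently in the label queue
    enqueues = 0    # how many times label_enqueue ran
    for _ in range(num_iter):
        if size - batch_size < min_queue_size:
            size += num_examples    # re-enqueue a full epoch of labels
            enqueues += 1
        size -= batch_size          # one eval step consumes a batch
    return enqueues, size
```

For example, with 100 examples, batch size 10, min queue size 20 and 10 iterations, the refill fires once at the start and once near the end of the epoch, which matches the behaviour described above: the queue is topped up exactly when a dequeue would otherwise block.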

There is an optional allow_smaller_final_batch argument:

"allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue."
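The effect can be sketched in plain Python (this is a stdlib illustration, not TensorFlow): a finite stream is split into full batches, and the final batch simply contains whatever is left instead of blocking for more input.

```python
def batch_up_to(items, batch_size):
    """Yield batches of batch_size elements; the last batch may be
    smaller, mirroring allow_smaller_final_batch=True."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]
```

With 10 items and a batch size of 4 this yields two full batches and a final batch of 2, which is exactly the end-of-epoch case that otherwise causes the blocking described in the question.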