TensorFlow shuffle_batch() blocks at end of epoch
I am creating batches of input images using tf.train.shuffle_batch(). It takes a min_after_dequeue argument that ensures the internal queue holds at least that many elements, and blocks everything else if it does not.
images, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=FLAGS.batch_size,
    num_threads=num_preprocess_threads,
    capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
    min_after_dequeue=FLAGS.min_queue_size)
At the end of an epoch, when I run evaluation (I am sure the same holds for training, but I have not tested it), everything blocks. I worked out that at that point the internal shuffle-batch queue is left with fewer than min_after_dequeue elements. At that point in the program I would ideally just like to dequeue the remaining elements, but I am not sure how.
Apparently this kind of blocking on a TF queue can be stopped with the .close() method once you know there are no more elements to enqueue. But since the underlying queue is hidden inside the function, how do I call that method?
You are right that running the RandomShuffleQueue.close() operation will stop the dequeuing threads from blocking when there are fewer than min_after_dequeue elements in the queue.
The tf.train.shuffle_batch() function creates a tf.train.QueueRunner that performs operations on the queue in a background thread. If you start it as follows, passing a tf.train.Coordinator, you will be able to close the queue cleanly (based on the example here):
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord=coord)
while not coord.should_stop():
    sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(threads)
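The underlying issue is the standard bounded-queue shutdown problem: a thread blocked on a dequeue can only be released by an explicit close/stop signal. A plain-Python sketch (no TensorFlow; the sentinel plays the role that closing the queue plays in TF):

```python
import queue
import threading

_CLOSED = object()  # sentinel standing in for queue.close()

def consumer(q, results):
    while True:
        item = q.get()          # blocks while the queue is empty
        if item is _CLOSED:
            q.put(_CLOSED)      # propagate shutdown to any other consumers
            break
        results.append(item)

q = queue.Queue()
results = []
t = threading.Thread(target=consumer, args=(q, results))
t.start()
for i in range(5):
    q.put(i)
q.put(_CLOSED)                  # "close" the queue: unblocks the consumer
t.join()
```

Without the sentinel (or, in TF, without closing the queue via the coordinator), the consumer would block forever on the final get() once the producer runs out of elements.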
Here is the code I eventually got working, although it produces a bunch of warnings that the elements I enqueued were cancelled.
lv = tf.constant(label_list)
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
# if eval_data:
#     num_epochs = 1
# else:
#     num_epochs = None
file_fifo = tf.train.string_input_producer(filenames, shuffle=False,
                                           capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])

reader = tf.WholeFileReader()
result.key, value = reader.read(file_fifo)
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128, 128, 3])
result.uint8image = image
result.label = label_fifo.dequeue()

images, label_batch = tf.train.shuffle_batch(
    [result.uint8image, result.label],
    batch_size=FLAGS.batch_size,
    num_threads=num_preprocess_threads,
    capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
    min_after_dequeue=FLAGS.min_queue_size)
# In the eval file:
label_enqueue, images, labels = load_input.inputs()
# (restore from checkpoint in between)

coord = tf.train.Coordinator()
try:
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))
    num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    true_count = 0  # Counts the number of correct predictions.
    total_sample_count = num_iter * FLAGS.batch_size
    sess.run(label_enqueue)
    step = 0
    while step < num_iter and not coord.should_stop():
        end_epoch = False
        if step > 0:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                # Check whether enough elements are left in the queue.
                size = qr._queue.size().eval()
                if size - FLAGS.batch_size < FLAGS.min_queue_size:
                    end_epoch = True
        if end_epoch:
            # Enqueue more labels so the remaining batches can be filled.
            sess.run(label_enqueue)
        # Actually run the step.
        predictions = sess.run([top_k_op])
        step += 1
finally:
    coord.request_stop()
    coord.join(threads)
There is an optional argument, allow_smaller_final_batch:
"allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue."
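With allow_smaller_final_batch=True passed to tf.train.shuffle_batch(), the evaluation loop can cover every example instead of padding to a multiple of batch_size. The arithmetic this changes can be sketched in plain Python (eval_batch_plan is a hypothetical helper, not a TensorFlow API):

```python
import math

def eval_batch_plan(num_examples, batch_size):
    """Return (number of batches, size of the final batch) when the
    final batch is allowed to be smaller than batch_size."""
    num_batches = math.ceil(num_examples / batch_size)
    final = num_examples - (num_batches - 1) * batch_size
    return num_batches, final

print(eval_batch_plan(10, 4))       # (3, 2)
print(eval_batch_plan(10000, 128))  # (79, 16)
```

Note that with this flag the evaluation code must not assume every batch has exactly batch_size elements (e.g. total_sample_count should be num_examples, not num_iter * batch_size).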