TensorFlow input pipeline: problems generating batches

I am currently trying to write a TensorFlow data input pipeline using TensorFlow queues. My data consists of JPEG images with three channels (RGB), each 128x128 pixels.

My current problem is running my image_batch op: the call keeps hanging, and I'm not sure why.

Below is how I build the input pipeline.

I'm using three main functions:

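  1. read_my_file_format takes a filename_queue, attempts to load the file, and resizes the image.
  2. tensorflow_queue takes a list of objects and builds a TensorFlow FIFOQueue from it. The queue is then wrapped in a QueueRunner, which is registered with tf.train.add_queue_runner.
  3. shuffle_queue_batch returns the ops that fetch a batch of images and labels.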

Here is the code:

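import tensorflow as tf  # TF 1.x queue-based API

def read_my_file_format(filename_queue):
    # Read one whole JPEG file from the filename queue and decode it as RGB.
    reader = tf.WholeFileReader()
    filename, image_string = reader.read(filename_queue)
    image = tf.image.decode_jpeg(image_string, channels=3)
    # Resize to a fixed size so every image has the same static shape (output is float32).
    image = tf.image.resize_images(image, size=[256, 256])
    return image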

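def tensorflow_queue(lst, dtype, capacity=32):
    # Put the whole Python list into a FIFO queue of scalars; once started, the
    # QueueRunner runs the enqueue op in a loop, re-enqueueing the list.
    tensor = tf.convert_to_tensor(lst, dtype=dtype)
    fq = tf.FIFOQueue(capacity=capacity, dtypes=dtype, shapes=[()])
    fq_enqueue_op = fq.enqueue_many([tensor])
    tf.train.add_queue_runner(tf.train.QueueRunner(fq, [fq_enqueue_op]))
    return fq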

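def shuffle_queue_batch(image, label, batch_size, capacity=32, min_after_dequeue=10, threads=1):
    # Enqueue single (image, label) pairs into a RandomShuffleQueue and
    # dequeue them batch_size at a time.
    tensor_list = [image, label]
    dtypes = [tf.float32, tf.int32]
    # The queue needs fully defined shapes: (256, 256, 3) for the image, () for the label.
    shapes = [image.get_shape(), label.get_shape()]
    rand_shuff_queue = tf.RandomShuffleQueue(
                                capacity=capacity,
                                min_after_dequeue=min_after_dequeue,
                                dtypes=dtypes,
                                shapes=shapes
                                )
    rand_shuff_enqueue_op = rand_shuff_queue.enqueue(tensor_list)
    tf.train.add_queue_runner(tf.train.QueueRunner(rand_shuff_queue, [rand_shuff_enqueue_op] * threads))

    # dequeue_many blocks until at least min_after_dequeue + batch_size elements
    # are queued, so batch_size must not exceed capacity - min_after_dequeue.
    image_batch, label_batch = rand_shuff_queue.dequeue_many(batch_size)
    return image_batch, label_batch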

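def input_pipeline(filenames, classes, min_after_dequeue=10):
    # Filenames and labels come from two separate queues; they stay paired because
    # neither is shuffled and each shuffle-queue enqueue takes one item from both.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    classes_queue = tensorflow_queue(classes, tf.int32)
    image = read_my_file_format(filename_queue)
    label = classes_queue.dequeue()
    # BATCH_SIZE is a global defined elsewhere in the script.
    image_batch, label_batch = shuffle_queue_batch(image, label, BATCH_SIZE, min_after_dequeue=min_after_dequeue)

    return image_batch, label_batch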


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # get_image_data returns:
    #    filenames is a list of strings of the filenames
    #    classes is a list of ints
    #    datasize = number of images in dataset
    filenames, classes, datasize = get_image_data()


    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    image_batch, label_batch = input_pipeline(filenames, classes)

    print('Starting training')
    for ep in range(NUM_EPOCHS):
        total_loss = 0
        for _ in range(datasize // BATCH_SIZE * BATCH_SIZE):
            print('fetching batch')
            x_batch = sess.run([image_batch])
            print('x batch')
            y_batch = sess.run([label_batch])
            x_batch, y_batch = sess.run([image_batch, label_batch])

Thank you in advance.

I strongly recommend switching your input pipeline from tf.train queues to tf.data. Queue-based input pipelines are inefficient and hard to maintain.
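For comparison, here is a minimal sketch of the same pipeline written with tf.data, assuming TensorFlow 2.4 or later running eagerly (tf.data.AUTOTUNE is not available under that name in older releases). The function name make_dataset and its parameters are hypothetical. It pairs each filename with its label before shuffling, decodes and resizes each JPEG to 256x256 as read_my_file_format does, and needs no queue runners or Coordinator.

```python
import tensorflow as tf  # assumes TensorFlow 2.4+ with eager execution


def make_dataset(filenames, classes, batch_size, image_size=(256, 256),
                 shuffle_buffer=1000):
    """Hypothetical tf.data replacement for the queue-based input_pipeline.

    filenames: list of JPEG paths; classes: list of int labels in the same order.
    Yields (image_batch, label_batch) pairs.
    """
    # Pair filenames and labels up front so they cannot drift apart.
    ds = tf.data.Dataset.from_tensor_slices(
        (tf.constant(filenames), tf.constant(classes, dtype=tf.int32)))
    ds = ds.shuffle(shuffle_buffer)

    def load(path, label):
        # Read and decode one JPEG as RGB, then resize (tf.image.resize returns float32).
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        image = tf.image.resize(image, image_size)
        return image, label

    ds = ds.map(load, num_parallel_calls=tf.data.AUTOTUNE)
    # drop_remainder=True gives every batch the same static shape.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)
```

Build the dataset once and iterate it once per epoch, e.g. `for x_batch, y_batch in ds:`; Dataset.shuffle reshuffles on each iteration by default, so every epoch sees a new order.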

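Your code is mostly correct; it only needs one small change to work. It hangs because you start the queue runners before you create the queues. tf.train.start_queue_runners only starts the QueueRunners that have already been registered in the graph (through tf.train.add_queue_runner, which string_input_producer and your own helpers call) at the moment it is called. In your session block, input_pipeline runs after that call, so no runner thread is ever started, nothing fills the queues, and the dequeue_many op behind image_batch blocks forever. If you inspect the return value of start_queue_runners, you will see that the list of threads is empty.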
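To see why the order matters without TensorFlow, here is a toy model in plain Python using only the standard library. ToyGraph, build_pipeline, and dequeue are hypothetical stand-ins, not TensorFlow APIs: the toy graph keeps a list of registered fill functions, and its start_queue_runners, like the real one, only starts threads for the runners registered at the moment it is called.

```python
import queue
import threading


class ToyGraph:
    """Hypothetical stand-in for a TF1 graph and its queue-runner collection."""

    def __init__(self):
        self.runners = []  # functions that keep some queue filled

    def add_queue_runner(self, fill_fn):
        self.runners.append(fill_fn)

    def start_queue_runners(self):
        # Only the runners registered *right now* get a thread.
        threads = [threading.Thread(target=fn, daemon=True) for fn in self.runners]
        for t in threads:
            t.start()
        return threads


def build_pipeline(graph, items):
    """Create a queue and register a runner that fills it (like input_pipeline)."""
    q = queue.Queue(maxsize=4)

    def fill():
        for item in items:
            q.put(item)

    graph.add_queue_runner(fill)
    return q


def dequeue(q, timeout=0.5):
    """Like a dequeue op, but return None after `timeout` seconds instead of hanging."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


# Question's order: start the runners, then build the pipeline.
g1 = ToyGraph()
started = g1.start_queue_runners()
q1 = build_pipeline(g1, [1, 2, 3])
print(len(started), dequeue(q1))  # 0 None

# Answer's order: build the pipeline, then start the runners.
g2 = ToyGraph()
q2 = build_pipeline(g2, [1, 2, 3])
started = g2.start_queue_runners()
print(len(started), dequeue(q2))  # 1 1
```

In the first case the runner is registered but never started, which is exactly the state the original session code ends up in.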

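That said, Alexandre's suggestion is a good one: tf.data is the way to get a high-performance input pipeline. Queue runners are also incompatible with TensorFlow's new eager execution mode.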

Here is the relevant code, with input_pipeline called before start_queue_runners:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # get_image_data returns:
    #    filenames is a list of strings of the filenames
    #    classes is a list of ints
    #    datasize = number of images in dataset
    filenames, classes, datasize = get_image_data()

    # Build the pipeline FIRST so its queue runners are registered in the graph...
    image_batch, label_batch = input_pipeline(filenames, classes)

    # ...and only then start them: start_queue_runners launches only runners that already exist.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    print('Starting training')
    try:
        for ep in range(NUM_EPOCHS):
            total_loss = 0
            # One step per full batch in the dataset.
            for _ in range(datasize // BATCH_SIZE):
                # Fetch images and labels in ONE run call so both come from the same
                # dequeue_many; separate calls would dequeue two different batches.
                x_batch, y_batch = sess.run([image_batch, label_batch])
    finally:
        # Stop the queue-runner threads and wait for them to exit.
        coord.request_stop()
        coord.join(threads)