Resizing flattened images loaded from TFRecord files

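Is it really necessary to store image dimension information in TFRecord files? I am currently working with a dataset made up of images with different sizes, and I did not store the width, height, or number of channels for the images I processed. Now I have a problem: after loading the TFRecords I need to reshape the images back to their original shape so that I can run the rest of the preprocessing pipeline, such as data augmentation.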

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Create dataset
records_path = DATA_DIR + 'TFRecords/train_0.tfrecords'
dataset = tf.data.TFRecordDataset(filenames=records_path)

# Parse dataset
parsed_dataset = dataset.map(parsing_fn)

# Take the first (image, label) pair; in TF 2 eager mode a plain Python
# iterator replaces tf.compat.v1.data.make_one_shot_iterator.
image, label = next(iter(parsed_dataset))

# Get the numpy array from the tensor, convert to uint8 and plot the image.
# At this point image is a flat 1-D tensor (parsing_fn does not reshape it).
img_array = image.numpy()
img_array = img_array.astype(np.uint8)
plt.imshow(img_array)
plt.show()

Output: TypeError: Invalid dimensions for image data

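Should I reshape the image back to its original shape before converting it to uint8? If so, how can I do that when I did not store the dimension information?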
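Because decode_raw returns only the pixel values, the length of the flat array alone does not pin down the shape. The small pure-Python illustration below (candidate_shapes is a hypothetical helper, and it assumes the channel count is known) lists every (height, width) pair that fits a given number of values:

```python
def candidate_shapes(num_values, channels=3):
    """List every (height, width) pair consistent with a flat array of num_values."""
    # If the length is not a multiple of the channel count, no shape fits.
    if num_values % channels != 0:
        return []
    pixels = num_values // channels
    # Every divisor h of the pixel count gives a valid (h, pixels // h) layout.
    return [(h, pixels // h) for h in range(1, pixels + 1) if pixels % h == 0]

# A 12-value buffer with 3 channels could be 1x4, 2x2 or 4x1 pixels.
print(candidate_shapes(12))  # [(1, 4), (2, 2), (4, 1)]
```

For a real photo the list is usually long, so without stored dimensions (or a format such as JPEG that records them) there is no reliable way to reshape the flat tensor.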

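The pipeline below shows an example transformation that I want to apply to the images read from the TFRecord, but I believe these Keras augmentation methods need properly shaped arrays with defined dimensions in order to run. (I don't necessarily need to plot the images.)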

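import numpy as np
import tensorflow as tf
from matplotlib import pyplot

def brightness(brightness_range, image_path):
    # Load the image from disk as a PIL image, then convert it to a
    # float array of shape (height, width, channels).
    img = tf.keras.preprocessing.image.load_img(image_path)
    data = tf.keras.preprocessing.image.img_to_array(img)
    # ImageDataGenerator.flow expects a batch: (1, height, width, channels).
    samples = np.expand_dims(data, 0)
    print(samples.shape)
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(brightness_range=brightness_range)
    iterator = datagen.flow(samples, batch_size=1)
    # Draw nine randomly brightened versions in a 3x3 grid.
    for i in range(9):
        pyplot.subplot(330 + 1 + i)
        batch = next(iterator)
        image = batch[0].astype('uint8')
        pyplot.imshow(image)
    pyplot.show()

brightness([0.2, 1.0], DATA_DIR + "183350/5c3e30f1706244e9f199d5a0c5a5ec00d1cbf473.jpg")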
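Once the images have a proper (height, width, channels) shape, augmentation can also run directly on tensors inside the tf.data pipeline instead of going through ImageDataGenerator. The sketch below approximates brightness_range=[0.2, 1.0] by multiplying the pixels by a random factor. This is a swapped-in technique, not what ImageDataGenerator does internally, and random_brightness_factor is a hypothetical name:

```python
import tensorflow as tf

def random_brightness_factor(image, low=0.2, high=1.0):
    """Scale pixel values by a random factor in [low, high)."""
    # Assumes image is a float32 tensor of shape (height, width, channels)
    # with values in [0, 255], as produced by a parsing function that reshapes.
    factor = tf.random.uniform([], minval=low, maxval=high)
    return tf.clip_by_value(image * factor, 0.0, 255.0)

# Usage on a dataset whose images already have a known 3-D shape.
images = tf.random.uniform([4, 32, 48, 3], maxval=255.0)
dataset = tf.data.Dataset.from_tensor_slices(images)
augmented = dataset.map(random_brightness_factor)
for img in augmented.take(1):
    print(img.shape)  # (32, 48, 3)
```

In the question's setup this would be chained after the parsing step, e.g. dataset.map(parsing_fn).map(...), which only works once parsing_fn returns a 3-D image.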

Helper functions for writing and reading the TFRecord format

Converting to TFRecord:

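import tensorflow as tf
from matplotlib.image import imread

# Assumed definitions of the helpers used below (they were not shown in the post).
def wrap_int64(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def wrap_bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def print_progress(count, total):
    print("\r- Progress: {0:.1%}".format(count / max(total, 1)), end="")

def convert(image_paths, labels, out_path):
    # Args:
    # image_paths   List of file-paths for the images.
    # labels        Class-labels for the images.
    # out_path      File-path for the TFRecords output file.

    print("Converting: " + out_path)

    # Number of images. Used when printing the progress.
    num_images = len(image_paths)

    # Open a TFRecordWriter for the output-file
    # (tf.io.TFRecordWriter in TF 2; tf.python_io.TFRecordWriter in TF 1).
    with tf.io.TFRecordWriter(out_path) as writer:

        # Iterate over all the image-paths and class-labels.
        for i in range(num_images):
            # Print the percentage-progress.
            print_progress(count=i, total=num_images - 1)

            # Load the image-file using matplotlib's imread function.
            # For a JPEG this is a uint8 array of shape (height, width, channels).
            path = image_paths[i]
            img = imread(path)
            path = path.split('/')

            # Convert the image to raw bytes. Only the pixel values are kept;
            # the (height, width, channels) shape is lost at this point.
            img_bytes = img.tobytes()

            # Get the label index from the fifth path component.
            label = int(path[4])

            # Create a dict with the data we want to save in the
            # TFRecords file. You can add more relevant data here.
            data = {
                'image': wrap_bytes(img_bytes),
                'label': wrap_int64(label)
            }

            # Wrap the data as TensorFlow Features.
            feature = tf.train.Features(feature=data)

            # Wrap again as a TensorFlow Example.
            example = tf.train.Example(features=feature)

            # Serialize the data.
            serialized = example.SerializeToString()

            # Write the serialized data to the TFRecords file.
            writer.write(serialized)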
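With raw bytes, the answer to the opening question is yes: the shape has to be stored next to the data. A minimal in-memory sketch, assuming the same raw-bytes approach as convert() but adding height, width and channels features (the feature names and the helpers serialize_with_shape and parse_with_shape are hypothetical):

```python
import numpy as np
import tensorflow as tf

def _bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def serialize_with_shape(img, label):
    """Serialize a uint8 (height, width, channels) array together with its dimensions."""
    height, width, channels = img.shape
    feature = {
        'image': _bytes(img.tobytes()),
        'height': _int64(height),
        'width': _int64(width),
        'channels': _int64(channels),
        'label': _int64(label),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

def parse_with_shape(serialized):
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'channels': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    # decode_raw gives a flat 1-D uint8 tensor ...
    image = tf.io.decode_raw(parsed['image'], tf.uint8)
    # ... which the stored dimensions turn back into (height, width, channels).
    shape = tf.stack([parsed['height'], parsed['width'], parsed['channels']])
    image = tf.reshape(image, shape)
    return tf.cast(image, tf.float32), parsed['label']

img = np.random.randint(0, 256, size=(5, 7, 3), dtype=np.uint8)
image, label = parse_with_shape(serialize_with_shape(img, 2))
print(image.shape)  # (5, 7, 3)
```

The reshape is exact because tobytes() writes the array in C order and decode_raw reads the bytes back in the same order.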

Parsing function:

def parsing_fn(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.
    features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                             features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes into a flat 1-D uint8 tensor. The bytes carry no
    # (height, width, channels) information, so the shape is not restored here.
    image = tf.io.decode_raw(image_raw, tf.uint8)
    
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)

    # Get the label associated with the image.
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label

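When converting, use tf.io.encode_jpeg, and when parsing, use tf.io.decode_jpeg. A JPEG stores its own height, width and number of channels in its header, so decoding it gives back a tensor with the original dimensions.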

More specifically, encode it like this:

image_bytes = tf.io.gfile.GFile(path, 'rb').read()
image = tf.io.decode_jpeg(image_bytes, channels=3)
# encode_jpeg returns a scalar string tensor; .numpy() gives the bytes for wrap_bytes().
image_bytes = tf.io.encode_jpeg(tf.cast(image, tf.uint8)).numpy()

and during parsing:

image = tf.io.decode_jpeg(image_raw, channels=3)
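The JPEG approach can be checked end to end without touching the file system. A minimal sketch, assuming 3-channel uint8 images (serialize_jpeg and parse_jpeg are hypothetical names), showing that the decoded tensor comes back with its original height and width even though no dimensions are stored explicitly:

```python
import numpy as np
import tensorflow as tf

def serialize_jpeg(image_uint8, label):
    """Encode an (height, width, 3) uint8 tensor as JPEG inside a tf.train.Example."""
    jpeg_bytes = tf.io.encode_jpeg(image_uint8).numpy()
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def parse_jpeg(serialized):
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    # The JPEG header carries height and width, so decode_jpeg returns a 3-D tensor.
    image = tf.io.decode_jpeg(parsed['image'], channels=3)
    return tf.cast(image, tf.float32), parsed['label']

original = tf.constant(np.random.randint(0, 256, size=(20, 30, 3), dtype=np.uint8))
image, label = parse_jpeg(serialize_jpeg(original, 1))
print(image.shape)  # (20, 30, 3)
```

JPEG is lossy, so the pixel values will differ slightly from the original; only the shape is guaranteed to survive. If the source files are already JPEGs, storing the bytes returned by tf.io.gfile.GFile(path, 'rb').read() directly avoids a second round of compression.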