如何从 ImageDataGenerator() 查看样本

How to view samples from ImageDataGenerator()

我目前正在学习使用 cifar10 图像的教程。我写了一些完整的工作代码,其中 model.fit(x_train, y_train) 行 x_train 作为维度 50000x32x32x3 和 dtype "uint8" 的 numpy 数组。 IE。它包含 50000 个 32x32 像素的彩色图像。我可以通过调用 imshow() 显示这些图像的样本 - 它看起来和工作正常。

但现在在本教程的下一部分中,它建议如果我们使用 ImageDataGenerator() 创建训练图像的多个变形(旋转、缩放、倾斜等)版本,模型将泛化得更好。我想通过显示在此过程中产生的一些变形图像来更好地理解 ImageDataGenerator()。查看 the documentation,它给出了以下示例:

# here's a more "manual" example
for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break

我当前的代码(没有变形)使用行 model.fit(x_train, y_train) 训练模型,所以查看示例中的行 model.fit(x_batch, y_batch) 我假设 x_batch 必须是一个集合当前 x_train 图像的 32 个不同变形版本。我试着写了一些代码,这样我就可以像这样实际显示 32 张图像:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def print_array_info(v):
    print("{} is of type {} with shape {} and dtype {}".format(v,
                                                           eval("type({})".format(v)),
                                                           eval("{}.shape".format(v)),
                                                           eval("{}.dtype".format(v))
                                                           ))

def show_samples(array_of_images):
    n = array_of_images.shape[0]
    total_rows = 1+int((n-1)/5)
    total_columns = 5
    fig = plt.figure()
    gridspec_array = fig.add_gridspec(total_rows, total_columns)

    for i, img in enumerate(array_of_images):
        row = int(i/5)
        col = i % 5
        ax = fig.add_subplot(gridspec_array[row, col])
        ax.imshow(img)

    plt.show()


cifar_data = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar_data.load_data()

data_generator = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=20)

print_array_info("x_train")

batches = 0
batch_size=32

for x_batch, y_batch in data_generator.flow(x_train, y_train, batch_size=batch_size):
    print_array_info("x_batch")
    batches += 1
    if batches >= len(x_train) / 32:
        break
    show_samples(x_batch[:batch_size])

我认为第一次通过循环时,我会看到 x_train 中第零张图像的 32 个不同变形版本。但是当我 运行 它产生几乎空白的图像时 - 我说几乎是因为其中一个或多个可能包含一些看起来像垃圾的像素。我预计 x_batch 的大小为 32x32x32x3,即 32 张大小为 32x32 像素的彩色图像和 3 种颜色的集合,这确实看起来是真的,但 dtype 是 float32 这让我感到困惑 - 我认为变形过程不会改变 dtype。

我的代码中有错误还是我误解了文档?

x_batch 是一个 float 类型的数组,因为对数据集应用了随机旋转。

当传递给 matplotlib.axes.Axes.imshow 的图像数据 (x_batch) 是 float 类型时,rgb 值必须在 0-[=16= 的范围内].将 x_batch 乘以 1/225 得到单位分数表示。

for x_batch, y_batch in data_generator.flow(x_train, y_train, batch_size=batch_size):
    print_array_info("x_batch")
    batches += 1
    if batches >= len(x_train) / batch_size:
        break
    show_samples(x_batch * 1 / 255)