如何设置 Keras Autoencoder 和 reshape() 以使用 ImageDataGenerator 处理 224 x 224 jpg 图像?

How to setup Keras Autoencoder and reshape() to process 224 x 224 jpg images using ImageDataGenerator?

我正在尝试将 Tensorflow Keras autoencoder implementation 应用到我自己的属于 40 类 的 224 x 224 图像数据集,我的设置如下:

我正在使用 ImageDataGenerator 创建训练集、验证集和测试集;训练集大小为 13,988;测试和验证大小均为 3,000。

但无论我如何设置这些参数,我都会收到如下错误:

InvalidArgumentError: Can not squeeze dim[3], expected a dimension of 1, got 3

或形状不匹配错误。

我想我没有正确设置我的自动编码器。有人可以发现问题并告诉我如何修复我的代码以使其正常工作吗?

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import os

from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator()

latent_dim =  64

class Autoencoder(Model):
  def __init__(self, latent_dim):
    super(Autoencoder, self).__init__()
    self.latent_dim = latent_dim   
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(latent_dim, activation='relu'),
    ])
    self.decoder = tf.keras.Sequential([
      layers.Dense(50176, activation='sigmoid'),
      layers.Reshape((224, 224))
    ])

  def call(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded
  
autoencoder = Autoencoder(latent_dim) 

autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())

train_datagen = datagen.flow_from_directory(
    'ae_data/train',
    target_size=(224, 224),
    batch_size=64,
    class_mode='input',
    color_mode='rgb')

validation_datagen = datagen.flow_from_directory(
    'ae_data/validation',
    target_size=(224, 224),
    batch_size=64,
    class_mode='input',
    color_mode='rgb')

test_datagen = datagen.flow_from_directory(
    'ae_data/test',
    target_size=(224, 224),
    batch_size=64,
    class_mode='input',
    color_mode='rgb')



autoencoder.fit(train_datagen, steps_per_epoch=218, validation_data=test_datagen, epochs=1,
               shuffle=False, validation_steps=8)

本教程使用时尚 MNIST 灰度图像。您可能正在使用 rgb 图像。

由于错误指出它不能将 3 个值压缩为 1,因此您的图像大小应为 224 x 224 x 3。三维表示 rgb 的 3 个值。

现在,如果颜色不重要,您可以先将图像预处理为灰度。