Convert model.fit_generator to model.fit

I have the following code:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

The model is then trained with model.fit_generator like this:

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

Now that model.fit_generator is deprecated, what is the correct way to change model.fit_generator to model.fit in this case?

You only need to change model.fit_generator() to model.fit().

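Starting with TensorFlow 2.1, model.fit() also accepts generators as input, so the rest of the call (steps_per_epoch, epochs, validation_data, validation_steps) stays the same. It's that simple.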
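To see the conversion run end to end without a data/train directory, here is a minimal sketch built on synthetic data. It assumes batch_generator as a hypothetical stand-in for train_generator and validation_generator, and the tiny model is made up purely for illustration; the only point is that the arguments previously passed to fit_generator go to model.fit unchanged.

```python
import numpy as np
import tensorflow as tf


def batch_generator(batch_size=32):
    """Hypothetical stand-in for train_generator: yields (images, labels) forever."""
    rng = np.random.default_rng(0)
    while True:
        x = rng.random((batch_size, 150, 150, 3), dtype=np.float32)
        y = rng.integers(0, 2, size=(batch_size, 1)).astype(np.float32)
        yield x, y


# Illustrative binary classifier for 150x150 RGB inputs (not from the question).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(150, 150, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Same keyword arguments as the old fit_generator call; only the method name changes.
# The generators are infinite, so steps_per_epoch and validation_steps are required.
history = model.fit(
    batch_generator(),
    steps_per_epoch=5,
    epochs=2,
    validation_data=batch_generator(),
    validation_steps=2,
    verbose=0,
)
```

With the real code from the question, the call would be model.fit(train_generator, steps_per_epoch=2000, epochs=50, validation_data=validation_generator, validation_steps=800).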

From the official TensorFlow documentation:

Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators.

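Remove 'generator='. model.fit has no generator parameter (its first argument is x), so pass the generator positionally instead.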

Old training call:
model.fit_generator(generator=train_generator,
    steps_per_epoch=2048//36, epochs=10,
    validation_data=validation_generator, validation_steps=832//16)
New training call:
model.fit(train_generator,
    steps_per_epoch=2048//36, epochs=10,
    validation_data=validation_generator, validation_steps=832//16)
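One detail worth checking when you carry the step counts over is that floor division drops the last partial batch whenever the sample count is not a multiple of the batch size. The sketch below assumes that 2048 and 832 are the training and validation sample counts and that 36 and 16 are batch sizes (the answer does not say), and steps_for is a hypothetical helper. As far as I know, the iterators returned by flow_from_directory are Keras Sequence objects whose length is already this ceiling, so with model.fit you can often omit steps_per_epoch and validation_steps entirely.

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Hypothetical helper: batches needed to cover every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


# Floor division skips the remainder; the ceiling includes the final partial batch.
print(2048 // 36, steps_for(2048, 36))  # 56 57
print(832 // 16, steps_for(832, 16))    # 52 52
```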