Convert model.fit_generator to model.fit
I have the following code:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
The model.fit_generator call currently looks like this:
model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800)
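Note that steps_per_epoch and validation_steps count batches, not images: with batch_size=32, steps_per_epoch=2000 means 64,000 images per epoch. If 2000 and 800 were meant as image counts (an assumption; the question does not say how many files data/train and data/validation hold), the step counts can be derived from the sample count and the batch size. A minimal sketch, using a hypothetical helper name:

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Hypothetical helper: batches needed to see every sample once.

    Rounds up, so a final partial batch is still counted.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    return math.ceil(num_samples / batch_size)


# Hypothetical image counts; with flow_from_directory the real counts are
# available as train_generator.samples and validation_generator.samples.
train_steps = steps_for(2000, 32)  # 63 batches
val_steps = steps_for(800, 32)     # 25 batches
```

Rounding up keeps the last partial batch; plain floor division (2000 // 32 == 62) would skip the remaining 16 images each epoch.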
Now that model.fit_generator is deprecated, what is the correct way to switch from model.fit_generator to model.fit in this case?
You can simply change model.fit_generator() to model.fit().
Since TensorFlow 2.1, model.fit() also accepts generators as input. It's that simple.
From the official TensorFlow documentation:
Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future
version. Instructions for updating: Please use Model.fit, which
supports generators.
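The generator passed to model.fit() does not have to be a Keras DirectoryIterator; a plain Python generator that yields (inputs, targets) tuples also works. The sketch below is a hypothetical, framework-free batch generator over in-memory sequences. Because it loops forever, steps_per_epoch must be given so Keras knows where an epoch ends. The fit call is left as a comment because it needs a compiled model, and in real use each batch would be converted to a NumPy array.

```python
import itertools
from typing import Iterator, List, Sequence, Tuple


def batch_generator(
    inputs: Sequence, targets: Sequence, batch_size: int
) -> Iterator[Tuple[List, List]]:
    """Hypothetical helper: yield (inputs, targets) batches forever.

    The final batch of each pass may be smaller than batch_size; after it,
    iteration starts again from the first sample.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    for start in itertools.cycle(range(0, len(inputs), batch_size)):
        end = start + batch_size
        yield list(inputs[start:end]), list(targets[start:end])


# With a compiled Keras model and array data (wrap each batch with
# np.asarray in real use), the generator goes straight into fit():
# model.fit(batch_generator(x_train, y_train, 32),
#           steps_per_epoch=(len(x_train) + 31) // 32, epochs=50)
```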
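Remove generator=: model.fit() has no parameter by that name, so pass the generator as the first positional argument instead.
Old training call: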
model.fit_generator(generator=train_generator,
                    steps_per_epoch=2048//36, epochs=10,
                    validation_data=validation_generator, validation_steps=832//16)
New training call:
model.fit(train_generator,
          steps_per_epoch=2048 // 128, epochs=10,
          validation_data=validation_generator, validation_steps=832//16)
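When many call sites build fit_generator arguments as keyword dicts, the rename can be done mechanically: model.fit() calls its first parameter x, so a generator= entry simply becomes x=. The helper below is hypothetical and assumes, as holds for the TensorFlow 2.x Model.fit signature, that generator is the only keyword whose name differs; every other key is copied unchanged.

```python
from typing import Any, Dict


def fit_kwargs_from_fit_generator(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rename fit_generator keywords for model.fit().

    Assumes 'generator' is the only keyword whose name differs; other keys
    (steps_per_epoch, epochs, validation_data, ...) are copied unchanged.
    """
    converted = dict(kwargs)  # copy so the caller's dict is not mutated
    if "generator" in converted:
        if "x" in converted:
            raise ValueError("both 'generator' and 'x' were given")
        converted["x"] = converted.pop("generator")
    return converted


# A placeholder string stands in for the real generator object here.
old_kwargs = {"generator": "train_generator", "steps_per_epoch": 2048 // 36, "epochs": 10}
new_kwargs = fit_kwargs_from_fit_generator(old_kwargs)
# model.fit(**new_kwargs)  # with a real generator instead of the placeholder
```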