保存模型、加载模型并继续训练的最佳方式是什么?

What is the best way of saving a model, Loading it and continuing training?

我在这里浏览了很多答案,但发现它们部分正确或无用。

我有一个 keras/tensorflow 模型需要训练。在这次训练中,我的模型是

我该怎么办?

根据您的需要,您可以使用 model.save() or use a ModelCheckPoint callback.

我找到了一个简单的方法。虽然这种方法最初是由其他用户提出的,但总是有寻求者抱怨,关于模型训练的重置以及他们加载的模型正在以非常低的验证准确度进行训练(而不是他们在最后一天停止的地方) . that.I 想用一个例子来证明这一点:

假设我将模型定义为:

def get_model_classif_nasnet():
    inputs = Input((224, 224, 3))
    #Other layers not shown here...
    model = Model(inputs, out)
    model.compile(optimizer=Adam(0.0001), loss=binary_crossentropy, metrics=['acc'])
    model.summary()
    return model

我们将在每个时期后保存我们的模型进度。为此,我们使用检查点。如果不同保存的检查点有相关的名称,我们也会很高兴。(即,名称应该详细说明它经历的训练环境)

h5_path = "weights-improvement-{epoch:02d}-{val_loss:.4f}-{val_acc:.2f}.h5"

checkpoint = ModelCheckpoint(h5_path,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True,
                             mode='max'
                            )

现在,让我们用上面的知识来保存一个模型吧 -

1)训练并保存模型

2)加载它

3)继续训练

1)

#Initialize a model
old_model = get_model_classif_nasnet()

#Let's train it
 batch_size = 32

 history = old_model.fit_generator(
    #Training and Validation data...
    epochs=2, verbose=1,
    callbacks=[checkpoint],
    #Some other parameters (not necessarily present in your method)
    steps_per_epoch = len(train) // batch_size,
    validation_steps=len(val) // batch_size
)

你的进度应该是这样的:

注意检查点在第 1 个时期后被保存。现在假设我们在第二个时期中途结束训练。所以我们有一张 checkpoint/model 图像保存为 .h5 文件

2)加载它

#Again initialize a model
new_model = get_model_classif_nasnet()

3) 继续训练

#There is nothing new here
batch_size = 32

 history = new_model.fit_generator(
    #Training and Validation data...
    epochs=8, verbose=1,
    callbacks=[checkpoint],
    #Some other parameters (not necessarily present in your method)
    steps_per_epoch = len(train) // batch_size,
    validation_steps=len(val) // batch_size
)

所以,就是这样。最重要的是,即使在完成所有这些之后,您也需要确保从一开始就保持低学习率 也许您注意到了 optimizer=Adam(0.0001) 这是这里的关键。一位用户引用 "It is happening because model.save(filename.h5) does not save the state of the optimizer. "