从 GridSearchCV 模型中获取每个 epoch 的验证分数
Get each epoch's validation scores from GridSearchCV models
我正在将 GridSearchCV 与 keras 一起使用,我想绘制和分析训练与验证历史记录。但是,我已经检查了文档并真正搜索了 SO,但是当使用 GridSearchCV 拟合模型时,我找不到获取验证历史记录(即每个时期的分数)的方法。我能够在回调中获得训练历史,但不能获得验证历史。问题是有些模型过度拟合,我希望能够看到调整参数如何影响过度拟合。
我是这样使用 GridSearchCV 的:
class MyCallback(keras_callbacks.Callback):
def on_train_end(self, logs=None):
# here I can get the model history from self.model.history.history
def create_model(...):
...
model = Model(...)
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=['acc'])
return model
callbacks = [MyCallback()]
model = KerasClassifier(build_fn=create_model, verbose=3)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=4, cv=3, verbose=10, return_train_score=True)
grid_result = grid.fit(X_train["padded"], y_train["binary"], epochs=30, batch_size=16, callbacks=callbacks)
您希望在拟合 Keras 模型时跟踪验证性能,例如使用 validation_data
或 validation_split
(请参阅 here 以获取参考)。
然而GridSearchCV
(来自sklearn)并不是很聪明地理解验证集(在CV拆分期间创建)必须与KerasClassifier
一起使用作为validation_data
才能跟踪scores/losses 每个纪元。
换句话说,您无法使用 GridSearchCV
跟踪每个验证集(在 CV 拆分期间创建)的性能
我正在将 GridSearchCV 与 keras 一起使用,我想绘制和分析训练与验证历史记录。但是,我已经检查了文档并真正搜索了 SO,但是当使用 GridSearchCV 拟合模型时,我找不到获取验证历史记录(即每个时期的分数)的方法。我能够在回调中获得训练历史,但不能获得验证历史。问题是有些模型过度拟合,我希望能够看到调整参数如何影响过度拟合。
我是这样使用 GridSearchCV 的:
class MyCallback(keras_callbacks.Callback):
def on_train_end(self, logs=None):
# here I can get the model history from self.model.history.history
def create_model(...):
...
model = Model(...)
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=['acc'])
return model
callbacks = [MyCallback()]
model = KerasClassifier(build_fn=create_model, verbose=3)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=4, cv=3, verbose=10, return_train_score=True)
grid_result = grid.fit(X_train["padded"], y_train["binary"], epochs=30, batch_size=16, callbacks=callbacks)
您希望在拟合 Keras 模型时跟踪验证性能,例如使用 validation_data
或 validation_split
(请参阅 here 以获取参考)。
然而GridSearchCV
(来自sklearn)并不是很聪明地理解验证集(在CV拆分期间创建)必须与KerasClassifier
一起使用作为validation_data
才能跟踪scores/losses 每个纪元。
换句话说,您无法使用 GridSearchCV