Callbacks in Keras give KeyError: 'metrics'?

The callbacks raise KeyError: 'metrics' when I train in Colab.

Dataset: SETI

!pip install livelossplot
from livelossplot.tf_keras import PlotLossesCallback
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from sklearn.metrics import confusion_matrix
from sklearn import metrics

np.random.seed(42)
import warnings
warnings.simplefilter('ignore')
%matplotlib inline
print('Tensorflow version:', tf.__version__)

. . .

model.compile(optimizer = optimizer, loss = 'categorical_crossentropy', metrics = ['accuracy'])
model.summary()

checkpoint = ModelCheckpoint("model_weights.h5", monitor='val_loss',
                             save_weights_only=True, mode='min', verbose=0)

my_callbacks = [PlotLossesCallback(), checkpoint]#, reduce_lr]

batch_size = 32
history = model.fit(
    datagen_train.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
    steps_per_epoch=len(x_train)//batch_size,
    validation_data = datagen_val.flow(x_val, y_val, batch_size=batch_size, shuffle=True),
    validation_steps = len(x_val)//batch_size,
    epochs=50,
    callbacks=my_callbacks
)

Error:

KeyError                                  Traceback (most recent call last)
<ipython-input-60-ff0dc86d079d> in <module>()
     11     validation_steps = len(x_val)//batch_size,
     12     epochs=12,
---> 13     callbacks=callbacks
     14 )

3 frames
/usr/local/lib/python3.6/dist-packages/livelossplot/generic_keras.py in on_train_begin(self, logs)
     29 
     30     def on_train_begin(self, logs={}):
---> 31         self.liveplot.set_metrics([metric for metric in self.params['metrics'] if not metric.startswith('val_')])
     32 
     33         # slightly convolved due to model.complie(loss=...) stuff

KeyError: 'metrics'

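You are importing through the old API; newer versions of livelossplot changed it. The traceback shows the old callback (generic_keras.py) failing when it reads self.params['metrics'].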
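The KeyError in the traceback means the 'metrics' key is absent from the params the old callback received. A minimal sketch of that failure mode in plain Python, with no Keras involved; the contents of `params` and both helper functions are hypothetical stand-ins for illustration, not livelossplot code:

```python
# Hypothetical stand-in for the params dict handed to a callback.
# Assumption: it has keys such as 'epochs' and 'steps' but no 'metrics'.
params = {"epochs": 50, "steps": 100, "verbose": 1}


def train_metrics_strict(params):
    """Mirror the failing line: index 'metrics' directly."""
    return [m for m in params["metrics"] if not m.startswith("val_")]


def train_metrics_tolerant(params):
    """Same filter, but treat a missing 'metrics' key as an empty list."""
    return [m for m in params.get("metrics", []) if not m.startswith("val_")]


try:
    train_metrics_strict(params)
except KeyError as err:
    # Same exception type as in the traceback above.
    print("KeyError:", err)

print(train_metrics_tolerant(params))
print(train_metrics_tolerant({"metrics": ["loss", "accuracy", "val_loss"]}))
```

The tolerant variant only shows how code can survive a missing key. It is not a claim about how newer livelossplot versions are implemented.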

Just change your import statement. Replace

from livelossplot.tf_keras import PlotLossesCallback

with

from livelossplot.inputs.tf_keras import PlotLossesCallback
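Before editing the notebook, it may help to check which of the two import paths the installed livelossplot actually provides. A small diagnostic sketch using only the standard library; `has_module` is a hypothetical helper, and the only module paths assumed are the two named in this answer:

```python
import importlib.util


def has_module(name):
    """Return True if module `name` can be located.

    Parent packages are imported during the lookup; the module itself is not.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or is not a package.
        return False


for path in ("livelossplot.inputs.tf_keras", "livelossplot.tf_keras"):
    print(path, "->", "available" if has_module(path) else "missing")
```

If the first path is reported as missing, the installed livelossplot is probably too old for the new import, and upgrading it with pip would be the next thing to try.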

See the livelossplot GitHub repository for more information and examples.