callbacks in keras gives KeyError: 'metrics'?
callbacks
Getting KeyError: 'metrics' when training in Colab
Dataset: SETI
!pip install livelossplot
from livelossplot.tf_keras import PlotLossesCallback
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import confusion_matrix
from sklearn import metrics
np.random.seed(42)
import warnings
warnings.simplefilter('ignore')
%matplotlib inline
print('Tensorflow version:', tf.__version__)
...
model.compile(optimizer = optimizer, loss = 'categorical_crossentropy', metrics = ['accuracy'])
model.summary()
checkpoint = ModelCheckpoint("model_weights.h5", monitor='val_loss',
save_weights_only=True, mode='min', verbose=0)
my_callbacks = [PlotLossesCallback(), checkpoint]#, reduce_lr]
batch_size = 32
history = model.fit(
datagen_train.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
steps_per_epoch=len(x_train)//batch_size,
validation_data = datagen_val.flow(x_val, y_val, batch_size=batch_size, shuffle=True),
validation_steps = len(x_val)//batch_size,
epochs=50,
callbacks=my_callbacks
)
Error:
KeyError Traceback (most recent call last)
<ipython-input-60-ff0dc86d079d> in <module>()
11 validation_steps = len(x_val)//batch_size,
12 epochs=12,
---> 13 callbacks=callbacks
14 )
3 frames
/usr/local/lib/python3.6/dist-packages/livelossplot/generic_keras.py in on_train_begin(self, logs)
29
30 def on_train_begin(self, logs={}):
---> 31 self.liveplot.set_metrics([metric for metric in self.params['metrics'] if not metric.startswith('val_')])
32
33 # slightly convolved due to model.complie(loss=...) stuff
KeyError: 'metrics'
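Your import uses the old livelossplot API; newer versions of livelossplot have changed some of their API. The old callback reads self.params['metrics'] in on_train_begin (see the last frame of the traceback), and newer TensorFlow releases no longer put a 'metrics' key into the parameters passed to callbacks, hence the KeyError.
Just change your import statement from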
from livelossplot.tf_keras import PlotLossesCallback
to
from livelossplot.inputs.tf_keras import PlotLossesCallback
See the livelossplot GitHub repository for more information and examples:
livelossplot-github
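To see why the old import breaks, here is a minimal pure-Python sketch (no TensorFlow needed) under stated assumptions. It assumes that recent tf.keras passes callbacks a params dict containing only verbose, epochs and steps. old_metric_names mirrors the failing line from the traceback, and metric_names_from_logs is a hypothetical helper showing one version-independent alternative: reading metric names from the logs dict that Keras passes to on_epoch_end. The numbers in the logs dict are dummy values for illustration only.

```python
# Assumption: recent tf.keras gives callbacks a params dict with only these keys
# (no 'metrics'); the values here are arbitrary.
new_style_params = {"verbose": 1, "epochs": 50, "steps": 100}


def old_metric_names(params):
    """Mirrors the failing line in livelossplot/generic_keras.py (on_train_begin)."""
    return [metric for metric in params["metrics"] if not metric.startswith("val_")]


def metric_names_from_logs(logs):
    """Hypothetical helper: split epoch-end log keys into training and validation names."""
    train = sorted(key for key in logs if not key.startswith("val_"))
    val = sorted(key for key in logs if key.startswith("val_"))
    return train, val


try:
    old_metric_names(new_style_params)
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: 'metrics'

# Dummy values shaped like the logs dict Keras passes to on_epoch_end
epoch_logs = {"loss": 0.9, "accuracy": 0.6, "val_loss": 1.1, "val_accuracy": 0.55}
print(metric_names_from_logs(epoch_logs))
# prints: (['accuracy', 'loss'], ['val_accuracy', 'val_loss'])
```

Upgrading livelossplot and using the new import path is the simpler fix; the logs-based approach is only a sketch of how a custom callback could avoid depending on self.params if upgrading is not an option.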