如何在训练期间的每个时期结束时调用测试集?我正在使用张量流
How can I call a test set at the end of each epoch during the training? I am using tensorflow
我正在使用 Tensorflow-Keras 开发一个 CNN 模型,我在其中将数据集拆分为训练集、验证集和测试集。我需要在每个时期结束时调用测试集以及训练集和验证集来评估模型性能。下面是我跟踪训练集和验证集的代码。
result_dic = {"epochs": []}
json_logging_callback = LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['acc']),
'val_acc': str(logs['val_acc'])
}))
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback])
输出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333
但是,我不确定如何将测试集添加到我的回调中以产生以下输出。
预期输出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333 - test_acc: xxx
要在每个时期后显示您的测试准确性,您可以自定义 fit
函数来显示此指标。查看此 documentation or you could, as shown here,为您的测试数据集定义一个简单的回调并将其传递到您的 fit
函数中:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback,
your_test_callback((X_test, Y_test))])
如果您想要完全的灵活性,您可以尝试使用 training loop。
更新:由于您希望所有指标都使用一个 JSON,因此您应该执行以下操作:
定义您的 TestCallBack
并将您的测试准确性(如果需要,还可以添加 loss
)到您的 logs
词典中:
import tensorflow as tf
class TestCallback(tf.keras.callbacks.Callback):
def __init__(self, test_data):
self.test_data = test_data
def on_epoch_end(self, epoch, logs):
x, y = self.test_data
loss, acc = self.model.evaluate(x, y, verbose=0)
logs['test_accuracy'] = acc
然后将测试准确度添加到您的结果字典中:
result_dic = {"epochs": []}
json_logging_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['accuracy']),
'val_acc': str(logs['val_accuracy']),
'test_acc': str(logs['test_accuracy'])
}))
然后在 fit
函数中使用两个回调,但注意回调的顺序:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=callbacks=[TestCallback((x_test, y_test)), json_logging_callback])
我正在使用 Tensorflow-Keras 开发一个 CNN 模型,我在其中将数据集拆分为训练集、验证集和测试集。我需要在每个时期结束时调用测试集以及训练集和验证集来评估模型性能。下面是我跟踪训练集和验证集的代码。
result_dic = {"epochs": []}
json_logging_callback = LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['acc']),
'val_acc': str(logs['val_acc'])
}))
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback])
输出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333
但是,我不确定如何将测试集添加到我的回调中以产生以下输出。
预期输出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333 - test_acc: xxx
要在每个时期后显示您的测试准确性,您可以自定义 fit
函数来显示此指标。查看此 documentation or you could, as shown here,为您的测试数据集定义一个简单的回调并将其传递到您的 fit
函数中:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback,
your_test_callback((X_test, Y_test))])
如果您想要完全的灵活性,您可以尝试使用 training loop。
更新:由于您希望所有指标都使用一个 JSON,因此您应该执行以下操作:
定义您的 TestCallBack
并将您的测试准确性(如果需要,还可以添加 loss
)到您的 logs
词典中:
import tensorflow as tf
class TestCallback(tf.keras.callbacks.Callback):
def __init__(self, test_data):
self.test_data = test_data
def on_epoch_end(self, epoch, logs):
x, y = self.test_data
loss, acc = self.model.evaluate(x, y, verbose=0)
logs['test_accuracy'] = acc
然后将测试准确度添加到您的结果字典中:
result_dic = {"epochs": []}
json_logging_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['accuracy']),
'val_acc': str(logs['val_accuracy']),
'test_acc': str(logs['test_accuracy'])
}))
然后在 fit
函数中使用两个回调,但注意回调的顺序:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=callbacks=[TestCallback((x_test, y_test)), json_logging_callback])