Keras 将预测的准确度计算为 +/- 1
Keras calculate accuracy as +/- 1 from predicted
我正在使用 Keras (+TensorFlow) 构建深度神经网络模型。在模型中,我需要定义自己的精度函数。
比方说,模型预测完成一项工作所花费的时间(以分钟为单位,介于 0 到 20 之间)。如果预测输出在 +/- 2 以内,我希望模型打印出准确度。如果预测输出是 x 分钟,而预期输出是 x+1,我想认为这是一个正确的预测,如果预期输出是x+3,我想认为这是一个错误的预测。
这与 top_k_categorical_accuracy
略有不同
您可以使用 Keras 后端 api 轻松实现逻辑..这也将确保您的指标在 tensorflow 和 theano 上都能正常工作。
这里有测试:
import numpy as np
import keras
from keras import backend as K
shift = 2
def custom_metric(y_true,y_pred):
diff = K.abs(K.argmax(y_true, axis=-1) - K.argmax(y_pred, axis=-1))
return K.mean(K.lesser_equal(diff, shift))
t1 = np.asarray([ [0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
])
p1 = np.asarray([ [0,0,0,0,0,1,0,0,0,],
[0,0,0,0,1,0,0,0,0,],
[0,0,0,0,0,0,0,1,0,],
[0,0,0,0,0,0,0,0,1,],
[1,0,0,0,0,0,0,0,0,],
[0,0,0,0,0,0,1,0,0,],
])
print K.eval(keras.metrics.categorical_accuracy(K.variable(t1),K.variable(p1)))
print K.eval(custom_metric(K.variable(t1),K.variable(p1)))
现在在您的 compile
语句中使用它:metrics=custom_metric
我正在使用 Keras (+TensorFlow) 构建深度神经网络模型。在模型中,我需要定义自己的精度函数。
比方说,模型预测完成一项工作所花费的时间(以分钟为单位,介于 0 到 20 之间)。如果预测输出在 +/- 2 以内,我希望模型打印出准确度。如果预测输出是 x 分钟,而预期输出是 x+1,我想认为这是一个正确的预测,如果预期输出是x+3,我想认为这是一个错误的预测。
这与 top_k_categorical_accuracy
您可以使用 Keras 后端 api 轻松实现逻辑..这也将确保您的指标在 tensorflow 和 theano 上都能正常工作。
这里有测试:
import numpy as np
import keras
from keras import backend as K
shift = 2
def custom_metric(y_true,y_pred):
diff = K.abs(K.argmax(y_true, axis=-1) - K.argmax(y_pred, axis=-1))
return K.mean(K.lesser_equal(diff, shift))
t1 = np.asarray([ [0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
[0,0,0,0,0,0,1,0,0,],
])
p1 = np.asarray([ [0,0,0,0,0,1,0,0,0,],
[0,0,0,0,1,0,0,0,0,],
[0,0,0,0,0,0,0,1,0,],
[0,0,0,0,0,0,0,0,1,],
[1,0,0,0,0,0,0,0,0,],
[0,0,0,0,0,0,1,0,0,],
])
print K.eval(keras.metrics.categorical_accuracy(K.variable(t1),K.variable(p1)))
print K.eval(custom_metric(K.variable(t1),K.variable(p1)))
现在在您的 compile
语句中使用它:metrics=custom_metric