获得多类分类的最差预测 类
Get worst predicted classes of multiclass classification
我正在研究一个多class class化问题,我有很多不同的 classes (50+)。
问题是,我想突出显示最差的预测 classes(例如,在混淆矩阵或其他方面),以便在我的 classifier 中做一些进一步的调整。
我的预测和测试数据保存在一个列表中(来自 sklearn 的小例子):
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
[0, 0, 1],
[1, 0, 2]])
如何从矩阵中得到本例中的 class 1?那里的预测是完全错误的。有没有办法根据 classes 的真阳性预测对它们进行排序?
您可以使用 scikit-learn 中的 classifiction_report,它将 return 字典,具有精确度、召回率和 F 分数。然后你可以按排序方式打印字典,这样你就可以很容易地看到最差的预测 class。
#prints classification_report
print(classification_report(y_true, y_pred)
#returns a dict, which you can easily sort by prediction
report = classification_report(y_true, y_pred, output_dict=True)
您可以为此使用一个简单的函数:
def print_class_accuracies(confusion_matrix):
# get the number of occurrences for each class
counts = {cl: y_true.count(cl) for cl in set(y_true)}
# extract the diagonal values (true positives)
tps = dict(enumerate(conf.diagonal()))
# Get the accuracy for each class, preventing ZeroDivisionErrors
pred_accuracy = {cl: tps[cl]/counts.get(cl, 1) for cl in tps}
# Get a ranking, worst accuracies are first/lowest
ranking = sorted([(acc,cl) for cl, acc in pred_accuracy.items()])
# Pretty print it
for acc, cl in ranking:
print(f"Class {cl}: accuracy: {acc:.2f}")
我正在研究一个多class class化问题,我有很多不同的 classes (50+)。
问题是,我想突出显示最差的预测 classes(例如,在混淆矩阵或其他方面),以便在我的 classifier 中做一些进一步的调整。
我的预测和测试数据保存在一个列表中(来自 sklearn 的小例子):
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
[0, 0, 1],
[1, 0, 2]])
如何从矩阵中得到本例中的 class 1?那里的预测是完全错误的。有没有办法根据 classes 的真阳性预测对它们进行排序?
您可以使用 scikit-learn 中的 classifiction_report,它将 return 字典,具有精确度、召回率和 F 分数。然后你可以按排序方式打印字典,这样你就可以很容易地看到最差的预测 class。
#prints classification_report
print(classification_report(y_true, y_pred)
#returns a dict, which you can easily sort by prediction
report = classification_report(y_true, y_pred, output_dict=True)
您可以为此使用一个简单的函数:
def print_class_accuracies(confusion_matrix):
# get the number of occurrences for each class
counts = {cl: y_true.count(cl) for cl in set(y_true)}
# extract the diagonal values (true positives)
tps = dict(enumerate(conf.diagonal()))
# Get the accuracy for each class, preventing ZeroDivisionErrors
pred_accuracy = {cl: tps[cl]/counts.get(cl, 1) for cl in tps}
# Get a ranking, worst accuracies are first/lowest
ranking = sorted([(acc,cl) for cl, acc in pred_accuracy.items()])
# Pretty print it
for acc, cl in ranking:
print(f"Class {cl}: accuracy: {acc:.2f}")