为多类分类可视化 predict_proba

visualize predict_proba for multiclass classification

使用 model.predict_proba(X) 我得到了一个包含很多数字的大数组。

我正在寻找一种方法来可视化所有 类(在我的例子中是 13)的分类概率。我使用 RandomForestClassifier.

有什么推荐吗?

热图是可视化二维矩阵的好方法。当然,如果你的 X 中的记录数量很大,那么很难一次将所有内容可视化。否则您可能必须对记录进行采样。在这里,我展示了前 10 条记录的视觉效果,如果预测概率大于 0.1,则标记为预测 类。

看看这个例子:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np


X, y = make_classification(n_samples=10000,n_features=40,
                           n_informative=30, n_classes=13,
                           n_redundant=0, n_clusters_per_class=1,
                           random_state=42)


X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=42)

forest = RandomForestClassifier(n_estimators=10, random_state=42).fit(X_train, y_train)

pred = forest.predict_proba(X_test)[:10]
fig, ax = plt.subplots(figsize= (20,8))
im = ax.imshow(pred, cmap='Blues')

ax.grid(axis='y')
ax.set_xticklabels([])

ax.set_yticks(np.arange(pred.shape[0]))

plt.ylabel('Records', fontsize='xx-large')
plt.xlabel('Classes', fontsize='xx-large')
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) 

for i in range(pred.shape[0]):
    for j in range(13):
        if pred[i, j] >.1:
             ax.text(j, i, j,
                     ha="center", va="center", color="w", fontsize=30)

如果您的输入 space 是二维的,或者如果您使用一些降维技术将其嵌入到二维中,您可以绘制多类决策面:

# generate toy data
X, y = sklearn.datasets.make_blobs(n_samples=1000, centers=13)

# fit classifier
clf = sklearn.ensemble.RandomForestClassifier().fit(X, y)

# create decision surface
xx, yy = np.meshgrid(np.linspace(-13, 12, 100),
                     np.linspace(-13, 12, 100))
Z = clf.predict(np.array([xx.ravel(), yy.ravel()]).T)
Z = Z.reshape(xx.shape)

fig, ax = plt.subplots(1,1, figsize=(8,8))
ax.scatter(X[:,0], X[:,1], c=y, cmap='Paired')
ax.contourf(xx, yy, Z, cmap='Paired', alpha=0.5)

请注意,这只是每个标签的阴影(predict 不是 predict_proba),但您可以根据概率将其扩展为不同的阴影。