Seaborn 热图:让注释文本完整显示在单元格内

Seaborn heatmap fit annotation text to cell

我有下面这段用于显示混淆矩阵的代码。每个单元格中先显示准确率,再在其下方显示“正确预测的样本数/样本总数”。现在我想让每个单元格内的文本完整显示出来。例如,第一个单元格应在准确率下方显示 186/208。如何才能在单元格内显示注释的完整文本?我尝试过减小字体大小,但没有效果。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20,15)):
    """Draw an annotated confusion-matrix heatmap and save it as "filename.png".

    Every cell is annotated with the cell's share of its row total followed,
    on a second line, by the raw count. Diagonal cells show ``count/row total``
    on the second line; off-diagonal cells whose count is zero stay blank.

    Args:
        cm: 2-D NumPy array of counts (``cm.shape`` is unpacked into rows and
            columns). The axis names set below treat rows as ground truth and
            columns as predictions.
        labels: Class names used for both the index and the columns of the
            plotted DataFrame.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # Row totals kept as an (nrows, 1) column so the division below is row-wise.
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float)
    # NOTE(review): astype(str) produces a fixed-width string dtype, so longer
    # annotations assigned below can be cut off silently; an object array does
    # not have that limit.
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                # Because of keepdims=True this is a one-element array, not a scalar.
                s = cm_sum[i]
                # NOTE(review): p is a fraction in [0, 1]; it is printed with a
                # '%' sign without being multiplied by 100.
                annot[i, j] = '%.2f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.2f%%\n%d' % (p, c)
    # Wrap the counts in a labelled DataFrame so seaborn picks up tick and axis labels.
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Groundtruth labels'
    cm.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')

    # fmt='' makes seaborn use the pre-built annotation strings verbatim.
    g =sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax, cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation = 45)
    # NOTE(review): called after the heatmap was drawn — presumably this only
    # influences plots created later; confirm against the seaborn docs.
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

# Confusion matrix of sample counts, one row per entry of `classes`
# (cm_analysis labels rows as ground truth and columns as predictions).
normalised_confusion_matrix = np.array([
    [186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
    [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
    [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
    [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
    [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
    [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
    [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
    [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
    [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26],
])

# Class names in the same order as the matrix rows and columns.
classes = [
    'Assemble system',
    'Consult sheets',
    'Picking in front',
    'Picking left',
    'Put down component',
    'Put down measuring rod',
    'Put down screwdriver',
    'Put down subsystem',
    'Take component',
    'Take measuring rod',
    'Take screwdriver',
    'Take subsystem',
    'Turn sheets',
]

# Render the heatmap; the function saves the figure to disk.
cm_analysis(normalised_confusion_matrix, classes)

主要问题是把 annot 数组创建成了 str 类型而不是 object 类型(因此应改为 annot = np.empty_like(cm).astype(object))。使用 str 类型会导致奇怪的错误,因为 NumPy 字符串数组的元素有固定的最大长度,超出部分会被截断。

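由于 cm_sum[i] 只使用一个索引,求和时最好不要“保持维度”,即写成 cm_sum = np.sum(cm, axis=1, keepdims=False)(参见 NumPy 文档)。注意:此时 cm_sum 的形状为 (n,),按行计算百分比时需要写成 cm / cm_sum[:, np.newaxis],否则广播会让每一列去除以下标相同的那一行的总和,非对角线单元格的百分比就会出错。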

另请注意,对于百分比,您需要乘以 100。(创建格式化字符串的现代方法是使用 f-string,例如:annot[i, j] = f'{p*100:.2f}%\n{c}/{s}'。)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20, 15), filename="filename.png"):
    """Plot an annotated confusion-matrix heatmap and save it to disk.

    Every cell is annotated with the percentage of its ground-truth row that
    the cell accounts for and, on a second line, the raw sample count.
    Diagonal cells show ``count/row total`` on the second line; off-diagonal
    cells with a zero count are left blank.

    Args:
        cm: 2-D array-like of sample counts. Rows are ground-truth classes and
            columns are predicted classes.
        labels: Class names, used for both axes, in the row/column order of
            ``cm``.
        figsize: Figure size forwarded to ``plt.subplots``.
        filename: Path the rendered figure is written to.
    """
    cm = np.asarray(cm)

    # Keep the row totals as an (nrows, 1) column vector. A flat (nrows,)
    # vector would broadcast along the columns instead, dividing cell (i, j)
    # by the total of row j and giving wrong off-diagonal percentages.
    row_totals = cm.sum(axis=1, keepdims=True)

    # Rows without any samples get 0% rather than NaN and a runtime warning.
    cm_perc = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    # dtype=object so annotation strings of any length are stored untruncated
    # (fixed-width NumPy string dtypes silently cut long values).
    annot = np.empty(cm.shape, dtype=object)
    for (i, j), count in np.ndenumerate(cm):
        percent = cm_perc[i, j] * 100
        if i == j:
            annot[i, j] = f'{percent:.1f}%\n{count}/{row_totals[i, 0]}'
        elif count == 0:
            annot[i, j] = ''
        else:
            annot[i, j] = f'{percent:.1f}%\n{count}'

    # A labelled DataFrame lets seaborn pick up tick labels and axis titles.
    frame = pd.DataFrame(cm, index=labels, columns=labels)
    frame.index.name = 'Groundtruth labels'
    frame.columns.name = 'Predicted labels'

    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')

    # fmt='' tells seaborn to use the pre-built annotation strings verbatim.
    g = sns.heatmap(frame, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax,
                    cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation=45)
    # NOTE: kept in its original position for backward compatibility; because
    # it runs after the heatmap is drawn it only affects plots created later.
    sns.set(font_scale=1.1)
    plt.savefig(filename)

# Confusion matrix of sample counts, one row per entry of `classes`
# (cm_analysis labels rows as ground truth and columns as predictions).
normalised_confusion_matrix = np.array([
    [186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
    [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
    [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
    [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
    [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
    [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
    [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
    [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
    [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26],
])

# Class names in the same order as the matrix rows and columns.
classes = [
    'Assemble system',
    'Consult sheets',
    'Picking in front',
    'Picking left',
    'Put down component',
    'Put down measuring rod',
    'Put down screwdriver',
    'Put down subsystem',
    'Take component',
    'Take measuring rod',
    'Take screwdriver',
    'Take subsystem',
    'Turn sheets',
]

# Render the heatmap; the function saves the figure to disk.
cm_analysis(normalised_confusion_matrix, classes)