如何仅注释 seaborn 热图的对角线元素
How to annotate only the diagonal elements of a seaborn heatmap
我正在使用 Seaborn 热图绘制大型混淆矩阵的输出。由于对角线元素代表正确的预测,因此它们更重要的是显示 number/correct 率。如问题所示,如何仅注释热图中的对角线条目?
我已经查阅了这个网站 https://seaborn.pydata.org/examples/many_pairwise_correlations.html,但它对如何仅注释对角线条目没有帮助。希望有人能帮忙。提前致谢!
这是否有助于您实现您的想法?您给出的 URL 示例没有对角线,我在主对角线下方注释了对角线。要注释您的混淆矩阵对角线,您可以通过将 np.diag(..., -1)
中的 -1 值更改为 0 来适应我的代码。
请注意我在 sns.heatmap(...)
中添加的附加参数 fmt=''
因为我的 annot
矩阵元素是字符串。
代码
from string import ascii_letters
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="white")
# Generate a large random dataset
rs = np.random.RandomState(33)
y = rs.normal(size=(100, 26))
d = pd.DataFrame(data=y, columns=list(ascii_letters[26:]))
# Compute the correlation matrix
corr = d.corr()
# Generate a mask for the upper triangle
mask = np.zeros_like(corr, dtype='bool')
mask[np.triu_indices_from(mask)] = True
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Generate the annotation
annot = np.diag(np.diag(corr.values,-1),-1)
annot = np.round(annot,2)
annot = annot.astype('str')
annot[annot=='0.0']=''
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=annot, fmt='')
plt.show()
输出
在相关问题中,有人问如何用字符串注释对角线元素。这是一个例子:
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
flights = sns.load_dataset('flights')
flights = flights.pivot('year', 'month', 'passengers')
corr_data = np.corrcoef(flights.to_numpy())
up_triang = np.triu(np.ones_like(corr_data)).astype(bool)
ax = sns.heatmap(corr_data, cmap='flare', xticklabels=False, yticklabels=False, square=True,
linecolor='white', linewidths=0.5,
cbar=True, mask=up_triang, cbar_kws={'shrink': 0.6, 'pad': 0.02, 'label': 'correlation'})
ax.invert_xaxis()
for i, label in enumerate(flights.index):
ax.text(i + 0.2, i + 0.5, label, ha='right', va='center')
plt.show()
我正在使用 Seaborn 热图绘制大型混淆矩阵的输出。由于对角线元素代表正确的预测,因此它们更重要的是显示 number/correct 率。如问题所示,如何仅注释热图中的对角线条目?
我已经查阅了这个网站 https://seaborn.pydata.org/examples/many_pairwise_correlations.html,但它对如何仅注释对角线条目没有帮助。希望有人能帮忙。提前致谢!
这是否有助于您实现您的想法?您给出的 URL 示例没有对角线,我在主对角线下方注释了对角线。要注释您的混淆矩阵对角线,您可以通过将 np.diag(..., -1)
中的 -1 值更改为 0 来适应我的代码。
请注意我在 sns.heatmap(...)
中添加的附加参数 fmt=''
因为我的 annot
矩阵元素是字符串。
代码
from string import ascii_letters
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="white")
# Generate a large random dataset
rs = np.random.RandomState(33)
y = rs.normal(size=(100, 26))
d = pd.DataFrame(data=y, columns=list(ascii_letters[26:]))
# Compute the correlation matrix
corr = d.corr()
# Generate a mask for the upper triangle
mask = np.zeros_like(corr, dtype='bool')
mask[np.triu_indices_from(mask)] = True
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Generate the annotation
annot = np.diag(np.diag(corr.values,-1),-1)
annot = np.round(annot,2)
annot = annot.astype('str')
annot[annot=='0.0']=''
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=annot, fmt='')
plt.show()
输出
在相关问题中,有人问如何用字符串注释对角线元素。这是一个例子:
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
flights = sns.load_dataset('flights')
flights = flights.pivot('year', 'month', 'passengers')
corr_data = np.corrcoef(flights.to_numpy())
up_triang = np.triu(np.ones_like(corr_data)).astype(bool)
ax = sns.heatmap(corr_data, cmap='flare', xticklabels=False, yticklabels=False, square=True,
linecolor='white', linewidths=0.5,
cbar=True, mask=up_triang, cbar_kws={'shrink': 0.6, 'pad': 0.02, 'label': 'correlation'})
ax.invert_xaxis()
for i, label in enumerate(flights.index):
ax.text(i + 0.2, i + 0.5, label, ha='right', va='center')
plt.show()