更改 x、y 轴标签与刻度之间的距离,并为 seaborn 热图中的注释统一选择一种颜色

Changing the distance between the axis labels and the x/y ticks, and choosing a single annotation color in a seaborn heatmap

我正在使用 seaborn 热图做相关矩阵,我需要:

  1. 更改刻度与 x 和 y 标签之间的距离。
  2. 此外,更改标题和热图之间的距离。
  3. 统一标注颜色为白色或黑色。

我正在使用以下代码:

from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import seaborn as sns

#call data frame and apply correlation:

#data = 
#df = pd.DataFrame(data, columns = features)
#df_small = df.iloc[:,:]#if only few parameters are needed
#correlation_mat = df_small.corr()

#Create color pallete: 

def NonLinCdict(steps, hexcol_array):
    """Build a ``LinearSegmentedColormap`` segment dictionary from colour stops.

    Each ``(step, colour)`` pair becomes an anchor ``(step, value, value)`` in
    the red, green and blue channel sequences. Because the "below" and "above"
    values are equal, the colour changes continuously (no jump) at each anchor.

    Parameters
    ----------
    steps : iterable of float
        Anchor positions. ``LinearSegmentedColormap`` expects them to increase
        from 0 to 1; this function does not check that.
    hexcol_array : iterable of str
        Colours understood by ``matplotlib.colors.hex2color`` (e.g. ``'#00549F'``).
        Stops are paired with ``zip``, so surplus entries in the longer input
        are silently ignored.

    Returns
    -------
    dict
        ``{'red': (...), 'green': (...), 'blue': (...)}``, one
        ``(step, value, value)`` tuple per anchor, in input order.
    """
    anchors = {'red': [], 'green': [], 'blue': []}
    for position, colour in zip(steps, hexcol_array):
        r, g, b = matplotlib.colors.hex2color(colour)
        anchors['red'].append((position, r, r))
        anchors['green'].append((position, g, g))
        anchors['blue'].append((position, b, b))
    # The colormap API accepts sequences; keep the original tuple-valued result.
    return {channel: tuple(rows) for channel, rows in anchors.items()}
 
#https://www.december.com/html/spec/colorshades.html

# Colour stops, lightest first. After mirroring below, the lightest colour sits
# at the centre of the map (correlation 0) and the darkest blue at both ends.
hc = ['#e5e5ff', '#C7DDF2', '#8EBAE5', '#407FB7', '#00549F']  # ffffff #e5e5ff
hc = hc[:0:-1] + hc  # prepend a reversed copy, but without repeating the central value
cdict = NonLinCdict(np.linspace(0, 1, len(hc)), hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)
corr = np.random.uniform(-1, 1, (6, 6))
n_rows, n_cols = corr.shape

# plot correlation matrix:

plt.figure(figsize=(10, 8))
ax = sns.heatmap(corr, center=0, linewidths=1, annot=True, cmap=cm, square=True, vmin=-1, vmax=1,
                 robust=True, annot_kws={'size': 16}, cbar=True, linecolor='#F6A800',
                 xticklabels=True, yticklabels=True)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=10, axis='both', which='both', length=0)
cbar.set_ticks(np.linspace(-1, 1, 11))
# NOTE(review): y=-1.5 places the title 1.5 axes-heights below the axes,
# i.e. outside the figure, so it is not visible.
plt.title("title", y=-1.5, fontsize=18)
plt.xlabel("X_parameters", fontsize=18)
plt.ylabel("Y_paramaters", fontsize=18)
ax.tick_params(axis='both', which='both', length=0)

# Thick frame around the cells. In a heatmap, rows run along y (0..n_rows)
# and columns along x (0..n_cols), so the bottom edge is at n_rows and the
# right edge at n_cols (these were swapped before; only visible when non-square).
ax.axhline(y=0, color='#F6A800', linewidth=4)
ax.axhline(y=n_rows, color='#F6A800', linewidth=4)
ax.axvline(x=0, color='#F6A800', linewidth=4)
ax.axvline(x=n_cols, color='#F6A800', linewidth=4)

# change position of labels and titles and assign colors.

plt.show()

我当前的输出是:

plt.title() 有一个参数 pad=,用于设置标题文本与绘图区顶部边框(spine)之间的间距,单位为点,默认值为 6。注意:问题代码中的 y=-1.5 会把标题放到坐标轴下方约 1.5 倍轴高处,超出图形范围而不可见。要用 pad 控制标题与热图的距离,应去掉 y 参数。plt.xlabel() 和 plt.ylabel() 有一个参数 labelpad=,用于设置轴标签与刻度标签之间的距离,单位同样为点。

sns.heatmap() 有一个参数 annot_kws,它是注释文本的参数字典。可以通过 sns.heatmap(..., annot_kws={'size': 16, 'color': 'black'}) 更改颜色。请注意,为了可读性,seaborn 的默认设置为深色单元格上的文本使用白色,浅色单元格上的文本使用黑色。

from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import seaborn as sns

def NonLinCdict(steps, hexcol_array):
    """Build a ``LinearSegmentedColormap`` segment dictionary from colour stops.

    Each ``(step, colour)`` pair becomes an anchor ``(step, value, value)`` in
    the red, green and blue channel sequences. Because the "below" and "above"
    values are equal, the colour changes continuously (no jump) at each anchor.

    Parameters
    ----------
    steps : iterable of float
        Anchor positions. ``LinearSegmentedColormap`` expects them to increase
        from 0 to 1; this function does not check that.
    hexcol_array : iterable of str
        Colours understood by ``matplotlib.colors.hex2color`` (e.g. ``'#00549F'``).
        Stops are paired with ``zip``, so surplus entries in the longer input
        are silently ignored.

    Returns
    -------
    dict
        ``{'red': (...), 'green': (...), 'blue': (...)}``, one
        ``(step, value, value)`` tuple per anchor, in input order.
    """
    anchors = {'red': [], 'green': [], 'blue': []}
    for position, colour in zip(steps, hexcol_array):
        r, g, b = matplotlib.colors.hex2color(colour)
        anchors['red'].append((position, r, r))
        anchors['green'].append((position, g, g))
        anchors['blue'].append((position, b, b))
    # The colormap API accepts sequences; keep the original tuple-valued result.
    return {channel: tuple(rows) for channel, rows in anchors.items()}

# Colour stops, lightest first. After mirroring below, the lightest colour sits
# at the centre of the map (correlation 0) and the darkest blue at both ends.
hc = ['#e5e5ff', '#C7DDF2', '#8EBAE5', '#407FB7', '#00549F']  # ffffff #e5e5ff
hc = hc[:0:-1] + hc  # prepend a reversed copy, but without repeating the central value
cdict = NonLinCdict(np.linspace(0, 1, len(hc)), hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)
corr = np.random.uniform(-1, 1, (6, 6))
n_rows, n_cols = corr.shape

# plot correlation matrix:
_, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr, center=0, linewidths=1, annot=True, cmap=cm, square=True, vmin=-1, vmax=1,
            robust=True,
            # 'color' overrides seaborn's automatic dark/light text choice, so
            # every annotation is drawn in the same colour.
            annot_kws={'size': 16, 'color': 'black'},
            cbar=True, linecolor='#F6A800',
            xticklabels=True, yticklabels=True, ax=ax)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=10, axis='both', which='both', length=0)
cbar.set_ticks(np.linspace(-1, 1, 11))

# pad / labelpad are gaps in points: title <-> top of the axes, and
# axis label <-> tick labels. The title deliberately has no y= override:
# y=-1.5 would put it 1.5 axes-heights below the axes, outside the figure,
# so the pad would have no visible effect.
ax.set_title("title", fontsize=18, pad=15)
ax.set_xlabel("X_parameters", fontsize=18, labelpad=15)
ax.set_ylabel("Y_paramaters", fontsize=18, labelpad=15)
ax.tick_params(axis='both', which='both', length=0)

# Thick frame around the cells. In a heatmap, rows run along y (0..n_rows)
# and columns along x (0..n_cols), so the bottom edge is at n_rows and the
# right edge at n_cols (these were swapped before; only visible when non-square).
ax.axhline(y=0, color='#F6A800', linewidth=4)
ax.axhline(y=n_rows, color='#F6A800', linewidth=4)
ax.axvline(x=0, color='#F6A800', linewidth=4)
ax.axvline(x=n_cols, color='#F6A800', linewidth=4)

plt.show()