使用 seaborn 绘制热图时如何将刻度定位到网格中心?

How to locate the ticks to center of grid when plotting heatmap with seaborn?

我用 seaborn 包绘制了热图,希望刻度标签位于每个网格单元的中心。我应该怎么做才能移动刻度线?

此外,我认为第一个图和第二个图之间的间距太窄,而第二个图和颜色条之间的间距太宽。我应该如何调整各个子图之间的间距?

下面是我的绘图代码,我附上了结果图。

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

sns.set(font_scale=1.2)
# width_ratios gives the two heatmaps equal width and the colorbar a thin column.
# wspace (0.2 is matplotlib's default) is the gap between every pair of columns;
# raise or lower it to tune the spacing between the subplots.
fig, axs = plt.subplots(ncols=3, figsize=(24, 10),
                        gridspec_kw=dict(width_ratios=[4, 4, 0.3], wspace=0.2))
# NOTE(review): main_pmatrix / after_pmatrix are defined elsewhere by the author.
df1 = pd.DataFrame(main_pmatrix)
df2 = pd.DataFrame(after_pmatrix)
g1 = sns.heatmap(df1, cmap='hot', ax=axs[0], cbar=False, vmin=0.9, vmax=1.8)
g2 = sns.heatmap(df2, cmap='hot', ax=axs[1], cbar=False, vmin=0.9, vmax=1.8)

xticklabels = ['1.0', '1.1', '1.2', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '1.9', '2.0',
               '2.1', '2.2', '2.3', '2.4', '2.5', '2.6', '2.7', '2.8', '2.9', '3.0']
yticklabels = ['0.5', '0.6', '0.7', '0.8', '0.9', '1.0', '1.1', '1.2', '1.3', '1.4', '1.5']

for g in (g1, g2):
    # Heatmap cell i spans [i, i + 1] in data coordinates, so its centre is i + 0.5.
    # Ticks at range(n) would sit on the cell borders instead of the centres.
    g.set_xticks([i + 0.5 for i in range(len(xticklabels))])
    g.set_xticklabels(xticklabels)
    g.set_yticks([i + 0.5 for i in range(len(yticklabels))])
    g.set_yticklabels(yticklabels)
    g.set_xlabel("Fractal dimension, $d_f$", fontsize=25, labelpad=10)
    g.set_ylabel("$b-$value", fontsize=25, labelpad=10)

# Both heatmaps share vmin/vmax, so either one's mesh can feed the colorbar.
plt.colorbar(axs[1].collections[0], cax=axs[2])
# Label the colorbar axes explicitly rather than relying on pyplot's "current axes".
axs[2].set_xlabel("$p$", fontsize=20, labelpad=10)
plt.show()

您可以使用 sns.heatmap 的 xticklabels= 和 yticklabels= 参数,在正确的位置设置所需的刻度标签。这样就不再需要 set_xticks 等调用了。子图之间的间距可以通过 gridspec_kw 中的 wspace 参数进行调整。

另请注意,不需要将矩阵转换为 pandas 数据帧,因为 seaborn 只会将数据帧转换回 2D numpy 数组。

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Demo data standing in for the asker's matrices (11 rows x 21 columns).
main_pmatrix = np.random.uniform(1.2, 1.8, (11, 21))
after_pmatrix = np.random.uniform(0.9, 1.5, (11, 21))
sns.set(font_scale=1.2)
# wspace controls the gap between every pair of columns (heatmap/heatmap and heatmap/colorbar).
fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(24, 10),
                                    gridspec_kw=dict(width_ratios=[4, 4, 0.3], wspace=0.2))

# One label per column/row, derived from the matrix shape so labels and cells stay in step.
# np.linspace hits both endpoints exactly; np.arange with a float step (0.1) would need a
# fudged stop value such as 3.001 to include 3.0.
n_rows, n_cols = main_pmatrix.shape
xticklabels = [f'{v:.1f}' for v in np.linspace(1.0, 3.0, n_cols)]
yticklabels = [f'{v:.1f}' for v in np.linspace(0.5, 1.5, n_rows)]

# Passing the labels to sns.heatmap lets seaborn place each tick at its cell centre.
sns.heatmap(main_pmatrix, cmap='hot', ax=ax1, cbar=False, vmin=0.9, vmax=1.8,
            xticklabels=xticklabels, yticklabels=yticklabels)
sns.heatmap(after_pmatrix, cmap='hot', ax=ax2, cbar=False, vmin=0.9, vmax=1.8,
            xticklabels=xticklabels, yticklabels=yticklabels)

for ax in (ax1, ax2):
    ax.set_xlabel("Fractal dimension, $d_f$", fontsize=25, labelpad=10)
    ax.set_ylabel("$b-$value", fontsize=25, labelpad=10)
ax1.set_title('Main p-matrix', fontsize=28)
ax2.set_title('After p-matrix', fontsize=28)
# Both heatmaps share vmin/vmax, so ax2's mesh is a valid mappable for the shared colorbar.
plt.colorbar(ax2.collections[0], cax=ax3)
ax3.set_xlabel("$p$", fontsize=20, labelpad=10)
plt.show()