Matplotlib:正确放置子图注释?

Matplotlib: correct placement of subplot annotations?

我想通过遍历数据,用 Matplotlib 创建子图。但是,注释没有被放到正确的位置,显然甚至没有被放到正确的子图中。此外,公共(共享)的 x 轴和 y 轴标签也不起作用。

我的真实数据更复杂,但这里有一个重现错误的示例:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

# Build the sample data: four normally distributed samples (100 draws each)
# plus two uniformly drawn reference values per sample.
distributions = []
first_values = []
second_values = []
for i in range(4):
    distributions.append(np.random.normal(0, 0.5, 100))
    first_values.append(np.random.uniform(0.7, 1))
    second_values.append(np.random.uniform(0.7, 1))

# Create a 2x2 grid of subplots and the custom legend handles
# (one line style per plotted element).
fig, axes = plt.subplots(2, 2, figsize = (15, 10))
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
                   Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
                   Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]

# Loop over the four data sets and draw each one into its own subplot.
for data in range(4):
    # Map the running index 0..3 onto the 2x2 grid, row by row.
    if data == 0:
        position = axes[0, 0]
    if data == 1:
        position = axes[0, 1]
    if data == 2:
        position = axes[1, 0]
    if data == 3:
        position = axes[1, 1]

    dist = distributions[data]
    first = first_values[data]
    second = second_values[data]

    # Histogram with KDE overlay and a rug plot of the raw samples.
    sns.histplot(dist, alpha = 0.5, kde = True, stat = 'density', bins = 20, color = '#76A29F', ax = position)
    sns.rugplot(dist, alpha = 0.5, color = '#76A29F', ax = position)
    # Label the sample mean. x is in data coordinates, y is a figure fraction.
    # NOTE(review): 'figure fraction' is relative to the whole figure rather
    # than this subplot, which looks like the cause of the misplaced labels.
    position.annotate(f'{np.mean(dist):.2f}', (np.mean(dist), 0.825), xycoords = ('data', 'figure fraction'), color = '#76A29F')
    # Vertical markers for the two reference values, spanning 0 to 0.75 of the
    # axes height.
    position.axvline(first, 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
    position.axvline(second, 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
    # Numeric labels for the two markers (same mixed coordinate system as above).
    position.annotate(f'{first:.2f}', (first, 0.8), xycoords = ('data', 'figure fraction'), color = '#FEB302')
    position.annotate(f'{second:.2f}', (second, 0.85), xycoords = ('data', 'figure fraction'), color = '#FF5D3E')
    # Ticks every 0.1 from just below the smallest sample to just above the
    # largest of the samples and the two reference values.
    position.set_xticks(np.arange(round(min(dist), 1) - 0.1, round(max(max(dist), max([first]), max([second])), 1) + 0.1, 0.1))

# NOTE(review): the pyplot calls below act on the current axes only
# (presumably the last subplot), not on the figure as a whole -- confirm.
plt.xlabel("x-axis name")
plt.ylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.5, 0.5))
plt.show()

结果图如下所示:

我希望得到的效果是:

  1. 注释显示在正确的子图中,并位于对应的垂直线/分布平均值旁边
  2. 所有子图共享 x 轴和 y 轴标签,或者至少每行/每列共享

非常感谢任何帮助!

如果用 axes.flatten() 把子图数组展平成一维数组,然后依次在每个子图上绘图,就可以得到正确的图形。另外,注释的 xycoords 不再使用 'figure fraction'(它是相对于整张图而不是单个子图的坐标)。出于测试目的,部分注释的颜色做了更改。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

# Fixed seed so the example figure is reproducible.
np.random.seed(202000104)

# Build the sample data: four normally distributed samples (100 draws each)
# plus two uniformly drawn reference values per sample.
distributions = []
first_values = []
second_values = []
for i in range(4):
    distributions.append(np.random.normal(0, 0.5, 100))
    first_values.append(np.random.uniform(0.7, 1))
    second_values.append(np.random.uniform(0.7, 1))

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Custom legend handles: one line style per plotted element.
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
                   Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
                   Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]

# Annotation coordinates: x in data units, y as a fraction of the *axes*
# height. This keeps every label inside its own subplot and just above the
# vertical lines (which end at 0.75 of the axes height), independent of how
# tall the density curve happens to be.
label_coords = ('data', 'axes fraction')

# axes.flatten() turns the 2x2 grid into a flat sequence, so each data set can
# be paired with its subplot directly.
for ax, dist, first, second in zip(axes.flatten(), distributions, first_values, second_values):
    mean = np.mean(dist)

    # Histogram with KDE overlay and a rug plot of the raw samples.
    sns.histplot(dist, alpha=0.5, kde=True, stat='density', bins=20, color='#76A29F', ax=ax)
    sns.rugplot(dist, alpha=0.5, color='#76A29F', ax=ax)
    ax.annotate(f'{mean:.2f}', (mean, 0.825), xycoords=label_coords, color='red')

    # Vertical markers for the two reference values and their numeric labels.
    ax.axvline(first, 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
    ax.axvline(second, 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
    ax.annotate(f'{first:.2f}', (first, 0.8), xycoords=label_coords, color='#FEB302')
    ax.annotate(f'{second:.2f}', (second, 0.85), xycoords=label_coords, color = '#FF5D3E')

    # Ticks every 0.1 from just below the smallest sample to just above the
    # largest of the samples and the two reference values.
    upper = max(max(dist), first, second)
    ax.set_xticks(np.arange(round(min(dist), 1) - 0.1, round(upper, 1) + 0.1, 0.1))

    # Drop the per-subplot axis labels; the figure-level labels below replace them.
    ax.set_xlabel('')
    ax.set_ylabel('')

# Figure-level labels shared by all subplots (plt.xlabel / plt.ylabel would
# only label the current, i.e. last, axes). Requires Matplotlib >= 3.4.
fig.supxlabel("x-axis name")
fig.supylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.35, 0.5))

plt.show()