如何在 seaborn 中为不同样式的线图和散点图创建正确的图例条目

How to create correct legend entries in seaborn for line plots of different styles and a scatter plot

我有一个问题:如何把三个线型不同的图组合在一起,并相应地调整图例。我有两个折线图,每个包含两条线,其中一个折线图使用实线,另一个使用虚线。在此之上,我又叠加了一个散点图。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import numpy as np
#sample data generation
# Two wide line datasets (intended: one drawn solid, one drawn dashed), each
# with two series, plus three scatter points per series.
plot_df = pd.DataFrame(index=np.arange(5), columns=["Series 1", "Series 2"], data=np.array([[1, 2],[2.4, 5],[4.1, 7.1],[5, 8.9],[5.2, 10]]))
# NOTE(review): "Series 2 dashed " carries a trailing space.
plot_df_dash = pd.DataFrame(index=np.arange(5), columns=["Series 1 dashed", "Series 2 dashed "], data=np.array([[2, 3],[3.4, 4],[5.1, 6.1],[7, 1.9],[4.2, 12]]))
plot_df_points = pd.DataFrame(index = [1.5, 2, 3.7], columns = ["Series 1", "Series 2"], data=np.array([[1.2, 3.4],[4.5, 6.9],[5.5, 9.6]]))
# Reshape the wide frames into long format: one row per (x, series, y).
df = pd.DataFrame(plot_df.stack()).reset_index()
# NOTE(review): BUG - this stacks plot_df, not plot_df_dash, so df_dash holds
# exactly the same values and series names as df. The dashed lines are then
# drawn on the same coordinates as the solid ones and cannot be seen.
df_dash = pd.DataFrame(plot_df.stack()).reset_index()
df.columns = ["x", "Series","y"]
df_dash.columns=["x", "Series dashed","y"]
df_points = pd.DataFrame(plot_df_points.stack()).reset_index()
df_points.columns = ["x", "Series","y"]

#plotting
fig, ax = plt.subplots()
sns.lineplot(data=df,x="x",y="y", hue="Series",ax=ax,palette="rocket",linewidth=2.5)
sns.lineplot(data=df_dash,x="x",y="y", hue="Series dashed",ax=ax,palette="rocket",linewidth=2.5,linestyle="--")
# NOTE(review): no palette is passed here, unlike the two lineplots above, so
# the markers presumably use seaborn's default colours rather than "rocket".
sns.scatterplot(data=df_points, x="x", y="y", hue="Series", ax=ax,s=200)

#generating legend
handles, labels = ax.get_legend_handles_labels()
# Groups every second handle together (presumably solid/dashed/scatter of the
# same series, assuming each seaborn call contributes two handles in order).
# NOTE(review): only two handle tuples are paired with four labels; the
# legend pairs handles and labels positionally, so presumably only the first
# two labels end up in the legend.
ax.legend([tuple(handles[::2]), tuple(handles[1::2])], labels[:4], handlelength=3,
          handler_map={tuple: HandlerTuple(ndivide=None)})

plt.show()
plt.close()

我一直没能得到正确的图例。散点图的标记应同时与实线图和虚线图组合显示。也就是说,图例应包含 4 个条目:两条带圆点的实线,以及两条带圆点的虚线,其中的圆点都来自同一个散点图。

正如我所说,一种方法是手动为图例句柄设置正确的线条属性(线型和线宽)。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import numpy as np

# sample data generation
# Two line datasets (drawn solid and dashed), each with two series, plus a
# sparse set of scatter points whose series names match the solid lines.
plot_df = pd.DataFrame(
    index=np.arange(5),
    columns=["Series 1", "Series 2"],
    data=np.array([[1, 2], [2.4, 5], [4.1, 7.1], [5, 8.9], [5.2, 10]]),
)
plot_df_dash = pd.DataFrame(
    index=np.arange(5),
    # no trailing whitespace: these names become legend labels
    columns=["Series 1 dashed", "Series 2 dashed"],
    data=np.array([[2, 3], [3.4, 4], [5.1, 6.1], [7, 1.9], [4.2, 12]]),
)
plot_df_points = pd.DataFrame(
    index=[1.5, 2, 3.7],
    columns=["Series 1", "Series 2"],
    data=np.array([[1.2, 3.4], [4.5, 6.9], [5.5, 9.6]]),
)

# Reshape each wide frame into long format: one row per (x, series, y).
df = pd.DataFrame(plot_df.stack()).reset_index()
# changed the dataframe generation here - the reason why you did not see dashed lines
df_dash = pd.DataFrame(plot_df_dash.stack()).reset_index()
df.columns = ["x", "Series", "y"]
df_dash.columns = ["x", "Series dashed", "y"]
df_points = pd.DataFrame(plot_df_points.stack()).reset_index()
df_points.columns = ["x", "Series", "y"]

# plotting
fig, ax = plt.subplots()
# Use the same palette for every plot so the solid line, dashed line and
# markers of a series share one colour. A bare sns.color_palette("rocket")
# call only returns a palette and does not apply it, so pass it explicitly.
palette = "rocket"
# linestyle and width of the dashed lines (reused below for the legend handles)
ls = "--"
lw = 3.5
sns.lineplot(data=df, x="x", y="y", hue="Series", ax=ax,
             palette=palette, linewidth=2.5)
sns.lineplot(data=df_dash, x="x", y="y", hue="Series dashed", ax=ax,
             palette=palette, linewidth=lw, linestyle=ls)
sns.scatterplot(data=df_points, x="x", y="y", hue="Series", ax=ax,
                palette=palette, s=200)

# generating legend
handles, labels = ax.get_legend_handles_labels()
# Expected handle order (two per plotting call): 0-1 solid lines,
# 2-3 dashed lines, 4-5 scatter markers. The legend proxies seaborn creates
# for the dashed lineplot do not show the dashed style, so restyle them here.
for i in [2, 3]:
    handles[i].set_linestyle(ls)
    handles[i].set_linewidth(lw)
# One entry per line, each combined with the scatter marker of its series.
ax.legend([(handles[0], handles[4]),
           (handles[1], handles[5]),
           (handles[2], handles[4]),
           (handles[3], handles[5])],
          labels[:4], handlelength=7,
          handler_map={tuple: HandlerTuple(ndivide=None)})

plt.show()

示例输出:

顺便说一句,你原来的代码中看不到虚线,是因为你错误地把 plot_df 的数据赋给了 df_dash(应当使用 plot_df_dash),所以虚线与实线完全重合。

如果你只想为每个系列生成一个图例条目(线型的含义通常会在图注中另行说明),那么上面代码中生成图例的部分可以简化为:

# generating legend (one entry per series)
handles, labels = ax.get_legend_handles_labels()

# Handles 2 and 3 are the dashed lineplot's legend proxies; give them the
# dashed style and width used for the plotted dashed lines.
for dashed_proxy in handles[2:4]:
    dashed_proxy.set_linestyle(ls)
    dashed_proxy.set_linewidth(lw)

# Even-indexed handles belong to the first series, odd-indexed to the second:
# each tuple combines that series' solid line, dashed line and marker.
first_series = tuple(handles[0::2])
second_series = tuple(handles[1::2])
ax.legend(
    [first_series, second_series],
    labels[:2],
    handlelength=10,
    handler_map={tuple: HandlerTuple(ndivide=None)},
)

plt.show()