How to pass different scatter kwargs to facets in lmplot in Seaborn

I'm trying to map a third variable to the scatter point colour in a Seaborn lmplot: total_bill on x, tip on y, and the point colour as a function of size.

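It works when faceting is not enabled, but fails when I use col, because the size of the colour array doesn't match the size of the data plotted in each facet.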

Here is my code:

    import matplotlib as mpl
    import seaborn as sns
    sns.set(color_codes=True)

    # load data
    data = sns.load_dataset("tips")

    # size of data
    print(len(data.index))

    ### we want to plot scatter point colour as function of variable 'size'

    # first, sort the data by 'size' so that high 'size' values are plotted
    # over the smaller sizes (so they are more visible)

    data = data.sort_values(by=['size'], ascending=True)

    scatter_kws = dict()
    cmap = mpl.colormaps['Blues']

    # normalise 'size' variable as float range needs to be
    # between 0 and 1 to map to a valid colour
    norm_size = data['size'] / data['size'].max()

    # map normalised values to RGBA colours; they are passed as 'color' because
    # regplot already sets a default 'color' in scatter_kws and recent
    # matplotlib refuses 'c' and 'color' together
    scatter_kws['color'] = cmap(norm_size.values)

    # colour array has same size as data
    print(len(scatter_kws['color']))

    # this works as intended
    g = sns.lmplot(data=data, x="total_bill", y="tip", scatter_kws=scatter_kws)

The above works well and produces the following plot:

[Figure: lmplot with point colour as function of size]

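However, when I add col='sex' to lmplot (see the code below), the problem is that the colour array is sized for the original dataset, which is larger than the data plotted in each facet. For example, the Male facet has 157 data points, so the first 157 values of the colour array are mapped onto those points (and they are not even the right values). See below: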

[Figure: lmplot with point colour as function of size, with col=sex]

    g = sns.lmplot(data=data, x="total_bill", y="tip", col="sex", scatter_kws=scatter_kws)
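The mismatch can be reproduced without any plotting. This is a minimal sketch with a made-up six-row DataFrame standing in for the tips data (an assumption, so it runs offline). It compares the first n entries of the global colour array, which is what the asker saw the Male facet receive, with the entries selected by that facet's own boolean mask.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the tips data: six rows, already sorted by 'size'
data = pd.DataFrame({
    "sex": ["Female", "Male", "Male", "Female", "Male", "Male"],
    "size": [1, 2, 2, 3, 4, 6],
})

# One normalised colour value per row of the *whole* DataFrame
colors = data["size"].to_numpy() / data["size"].max()

mask = (data["sex"] == "Male").to_numpy()
n_male = int(mask.sum())

# What the 'Male' facet ended up with: the first n_male values
wrong = colors[:n_male]
# What it should get: the values that belong to the 'Male' rows
right = colors[mask]

print(n_male, len(colors))        # 4 6
print(np.allclose(wrong, right))  # False
```

Selecting with the facet's mask keeps each colour attached to its own row, which is the idea the answer below relies on.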

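Ideally, I would like to pass an array of scatter_kws to lmplot so that each facet uses the correct colour array (which I would compute before calling lmplot). But that doesn't seem to be an option.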

Any other ideas or workarounds that still let me use the functionality of Seaborn's lmplot (that is, without re-creating the FacetGrid and lmplot functionality)?

In principle, an lmplot with several cols is essentially a wrapper around several regplots. So instead of one lmplot we can use two regplots, one for each sex.
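That wrapper relationship also suggests a way to keep the grid layout: build a FacetGrid directly and map a small regplot wrapper over it with FacetGrid.map_dataframe. The wrapper receives only its facet's rows as `data`, so it can compute exactly one colour per plotted point. This is a sketch, not an lmplot option; the wrapper `colored_regplot` and its `cvar`/`vmax`/`cmap` parameters are hypothetical, and random data stands in for the tips dataset (assumption) so it runs offline. The more explicit two-regplot approach of this answer follows.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the tips data (random values, fixed seed)
rng = np.random.default_rng(0)
n = 60
data = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n),
    "sex": rng.choice(["Male", "Female"], n),
    "size": rng.integers(1, 7, n),
})
data["tip"] = 0.15 * data["total_bill"] + rng.normal(0, 1, n)
data = data.sort_values("size")  # draw large 'size' values last


def colored_regplot(x, y, data=None, color=None, cvar="size", vmax=1.0,
                    cmap="Blues", **kwargs):
    """Hypothetical wrapper: a regplot whose points are coloured by `cvar`.

    map_dataframe passes only this facet's rows as `data`, so the RGBA
    array has one row per plotted point. `color` is the facet colour that
    map_dataframe supplies; it is used for the regression line here.
    """
    rgba = plt.get_cmap(cmap)(data[cvar].to_numpy() / vmax)
    sns.regplot(data=data, x=x, y=y, ax=plt.gca(),
                scatter_kws={"color": rgba},
                line_kws={"color": color}, **kwargs)


g = sns.FacetGrid(data, col="sex")
# vmax is the global maximum, so colours are comparable across facets
g.map_dataframe(colored_regplot, "total_bill", "tip", vmax=data["size"].max())
```

Because the function is not defined in seaborn, map_dataframe makes each facet's axes current before calling it, which is why `plt.gca()` picks up the right panel.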

So we need to split the original dataframe into male and female; the rest is straightforward.
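For more than two facets, the split can be generalised with groupby. This is a minimal sketch, assuming a small made-up DataFrame in place of the tips data. Normalising with the global maximum before splitting keeps one colour scale across all facets. Each group then has a colour array with exactly one entry per row, in the same order as its rows.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the tips data, already sorted by 'size'
data = pd.DataFrame({
    "sex": ["Female", "Male", "Male", "Female", "Male", "Male"],
    "size": [1, 2, 2, 3, 4, 6],
})
# Normalise with the *global* maximum so every facet shares one scale
data["norm_size"] = data["size"] / data["size"].max()

# One (sub-DataFrame, colour values) pair per level of 'sex';
# observed=True only matters when 'sex' is categorical, as in the tips data
facets = {
    name: (group, group["norm_size"].to_numpy())
    for name, group in data.groupby("sex", observed=True)
}

for name, (group, colors) in facets.items():
    # each colour array has exactly one entry per row of its facet
    print(name, len(group), colors)
```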

    import matplotlib.pyplot as plt
    import seaborn as sns

    data = sns.load_dataset("tips")

    data = data.sort_values(by=['size'], ascending=True)
    # make a new dataframe for males and females
    male = data[data["sex"] == "Male"]
    female = data[data["sex"] == "Female"]

    # get normalized colour values in the [0, 1] range for all data
    colors = data['size'].values / float(data['size'].max())
    # get colours for males / females; the boolean masks pick the same rows,
    # in the same order, as the male / female dataframes above
    colors_male = colors[(data["sex"] == "Male").values]
    colors_female = colors[(data["sex"] == "Female").values]

    # map the values to RGBA colours; regplot already puts a default 'color'
    # into scatter_kws and recent matplotlib refuses 'c' together with 'color',
    # so the per-point colours are passed as 'color'
    rgba_male = plt.get_cmap("Blues")(colors_male)
    rgba_female = plt.get_cmap("Greens")(colors_female)

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

    # create regplot for males, put it on the left axes
    sns.regplot(data=male, x="total_bill", y="tip", ax=ax1,
                scatter_kws={"color": rgba_male})
    # same for females, on the right axes
    sns.regplot(data=female, x="total_bill", y="tip", ax=ax2,
                scatter_kws={"color": rgba_female})
    ax1.set_title("Males")
    ax2.set_title("Females")
    for ax in [ax1, ax2]:
        ax.set_xlim([0, 60])
        ax.set_ylim([0, 12])
    plt.tight_layout()
    plt.show()
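A possible refinement, not part of the answer itself: dividing by the maximum maps the smallest size (1) to 1/6 rather than 0. The two panels also use different colormaps with no legend for the colours. matplotlib.colors.Normalize spreads the global min..max over the whole colormap. A ScalarMappable built from the same norm and cmap then gives one colour bar for both panels. The data below are made up, and plain ax.scatter calls stand in for regplot so the sketch runs without seaborn.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize

# Hypothetical per-facet data: x, y and the 'size' values to colour by
x_male, y_male = np.array([10, 20, 30, 40]), np.array([2, 3, 4, 6])
size_male = np.array([2, 2, 4, 6])
x_female, y_female = np.array([15, 35]), np.array([2, 5])
size_female = np.array([1, 3])

# One shared normaliser maps the global min..max of 'size' onto [0, 1]
all_sizes = np.concatenate([size_male, size_female])
norm = Normalize(vmin=all_sizes.min(), vmax=all_sizes.max())
cmap = plt.get_cmap("Blues")

# RGBA arrays with one row per point; with seaborn these would be passed as
# regplot(..., scatter_kws={"color": rgba_male}) instead of ax.scatter
rgba_male = cmap(norm(size_male))
rgba_female = cmap(norm(size_female))

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.scatter(x_male, y_male, color=rgba_male)
ax2.scatter(x_female, y_female, color=rgba_female)
ax1.set_title("Males")
ax2.set_title("Females")

# A ScalarMappable built from the same norm and cmap lets one colour bar
# describe the colours used in both panels
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
fig.colorbar(sm, ax=[ax1, ax2], label="size")
```

Using a single colormap for both panels is a design choice here: it makes the same colour mean the same size in every facet, at the cost of the per-sex colour distinction the answer uses.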