在 seaborn 分面(多子图)图中添加拟合曲线

seaborn fit lines into multi-facetted plot

对于以下数据,

import pandas as pd
# Sample data set in long format, one row per observation:
#   'year'   - one of 5, 10, 15, 20, 30
#   'month'  - one of 1, 3, 6, 12
#   'mid'    - the numeric value that is plotted and fitted
#   'source' - one of 'wood', 'metal', 'water'
# A given (year, month) pair can appear several times, once per source.
df = pd.DataFrame({
    'year': [ 5, 5, 5, 5, 5, 5, 5, 5, 5, 10, 10, 10, 10, 10, 10, 10, 10, 10,
             15, 15, 15, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20, 20, 20, 20,
             20, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30],
    'month': [1, 1, 1, 3, 3, 6, 6, 12, 12, 1, 1, 3, 3, 3, 6, 6, 12, 12, 1, 1, 3,
              3, 6, 6, 12, 12, 1, 1, 3, 3, 3, 6, 6, 12, 12, 1, 1, 1, 3, 3, 3, 6,
              6, 12, 12],
    'mid': [0.825, 0.83, 0.85, 0.935, 0.96, 1.055, 1.12, 1.305, 1.38, 1.235, 
            1.25, 1.345, 1.34, 1.34, 1.455, 1.46, 1.635, 1.65, 1.375, 1.39,
            1.475, 1.47, 1.585, 1.59, 1.785, 1.8, 1.535, 1.54, 1.635, 1.63,
            1.62, 1.755, 1.74, 1.96, 1.95, 1.695, 1.7, 1.71, 1.805, 1.8, 1.79,
            1.925, 1.91, 2.12, 2.1],
    'source': ['wood', 'metal', 'water', 'wood', 'water', 'wood', 'water', 
               'wood', 'water', 'wood', 'water', 'wood', 'metal', 'water',
               'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood',
               'water', 'wood', 'water', 'wood', 'water', 'wood', 'water',
               'wood', 'metal', 'water', 'wood', 'water', 'wood', 'water',
               'wood', 'metal', 'water', 'wood', 'metal','water', 'wood',
               'water', 'wood', 'water']})

按 'year' 分组,计算出 5 个 scipy 三次样条(CubicSpline)对象:

from scipy.interpolate import CubicSpline
def spline3(df, col1, col2):
    """Fit a natural cubic spline through two columns of a DataFrame.

    Args:
        df: DataFrame holding the points to interpolate.
        col1: Name of the column supplying the x coordinates.
        col2: Name of the column supplying the y coordinates.

    Returns:
        A ``CubicSpline`` built with ``bc_type='natural'``.
    """
    return CubicSpline(df[col1].values, df[col2].values, bc_type='natural')

# Average 'mid' over duplicate (year, month) rows so that every year has a
# single y value per month before fitting.
avg_df = df.groupby(['year', 'month'], as_index=False)['mid'].mean()

# Series indexed by year; each element is the spline fitted to that year.
splines = avg_df.groupby('year').apply(
    lambda group: spline3(group, 'month', 'mid'))

所以我可以绘制单个 'year' 组的拟合曲线:

import matplotlib.pyplot as plt
import numpy as np

# Plot the raw year-5 observations together with their fitted spline.

# Select the year-5 rows once instead of re-grouping df for each column.
year5 = df[df['year'] == 5]
x, y = year5['month'], year5['mid']

# Evaluate the spline on a dense grid: sampling it only at the observed
# months would draw straight segments between the knots rather than the
# fitted curve.
x_fit = np.linspace(x.min(), x.max(), 200)

fig, ax = plt.subplots(figsize=(6.5, 4))
ax.plot(x, y, 'o', label='data')
ax.plot(x_fit, splines[5](x_fit), label="S")
ax.legend()  # the labels above are only shown once a legend is drawn

如何在下面这个基于原始数据绘制的 seaborn 分面 relplot 中,把全部 5 条样条曲线分别画到对应的子图上?

import seaborn as sns
# One panel per 'year' (wrapped at 5 columns); 'source' controls both the
# marker colour and the marker style.
sns.relplot(
    data=df, x='month', y='mid',
    hue='source', style='source',
    palette=['green', 'orange', 'purple'],
    col='year', col_wrap=5,
    height=6, aspect=.5, legend=True,
)

我对问题的理解是:您希望在每个 seaborn 子图上叠加对应的样条曲线。下面的做法是在您的代码基础上实现的。

import numpy as np
import seaborn as sns

# Draw the faceted scatter plot, then overlay each facet with the spline
# fitted for the year that facet displays.
g = sns.relplot(x='month', y='mid', hue='source', data=df,
                palette=['green', 'orange', 'purple'],
                col='year', col_wrap=5, legend=True, height=6, aspect=.5,
                style='source')

# Group once up front rather than re-grouping df for every facet.
months_by_year = df.groupby('year')['month']

# Walk the facets together with the year each one shows, so the years come
# from the grid itself instead of a hard-coded list that has to be kept in
# sync with the data.
for ax, year in zip(g.axes.flat, g.col_names):
    months = months_by_year.get_group(year)
    # Dense grid so the spline is drawn as a smooth curve, not as straight
    # segments between the observed months.
    x_fit = np.linspace(months.min(), months.max(), 200)
    ax.plot(x_fit, splines[year](x_fit), c='C1')

plt.show()

您可以遍历子图并绘制样条线:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MultipleLocator
from scipy.interpolate import CubicSpline

def spline3(df, col1, col2):
    """Build a natural cubic spline from two DataFrame columns.

    Args:
        df: DataFrame containing the points to interpolate.
        col1: Column name whose values are the x coordinates.
        col2: Column name whose values are the y coordinates.

    Returns:
        The ``CubicSpline`` fitted with ``bc_type='natural'``.
    """
    knots = df[col1].values
    heights = df[col2].values
    spline = CubicSpline(knots, heights, bc_type='natural')
    return spline

# Sample data in long format, one row per observation: 'year' (5, 10, 15,
# 20 or 30), 'month' (1, 3, 6 or 12), the numeric 'mid' value, and the
# 'source' label ('wood', 'metal' or 'water'). A (year, month) pair can
# occur more than once, once per source.
df = pd.DataFrame({'year': [5, 5, 5, 5, 5, 5, 5, 5, 5, 10, 10, 10, 10, 10, 10, 10, 10, 10, 15, 15, 15, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30],
                   'month': [1, 1, 1, 3, 3, 6, 6, 12, 12, 1, 1, 3, 3, 3, 6, 6, 12, 12, 1, 1, 3, 3, 6, 6, 12, 12, 1, 1, 3, 3, 3, 6, 6, 12, 12, 1, 1, 1, 3, 3, 3, 6, 6, 12, 12],
                   'mid': [0.825, 0.83, 0.85, 0.935, 0.96, 1.055, 1.12, 1.305, 1.38, 1.235, 1.25, 1.345, 1.34, 1.34, 1.455, 1.46, 1.635, 1.65, 1.375, 1.39, 1.475, 1.47, 1.585, 1.59, 1.785, 1.8, 1.535, 1.54, 1.635, 1.63, 1.62, 1.755, 1.74, 1.96, 1.95, 1.695, 1.7, 1.71, 1.805, 1.8, 1.79, 1.925, 1.91, 2.12, 2.1],
                   'source': ['wood', 'metal', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'metal', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'metal', 'water', 'wood', 'water', 'wood', 'water', 'wood', 'metal', 'water', 'wood', 'metal', 'water', 'wood', 'water', 'wood', 'water']})
# Mean 'mid' per (year, month), kept as ordinary columns, so each year
# contributes one point per month to its spline.
avg_df = df.groupby(['year', 'month'], as_index=False)['mid'].mean()

# Series keyed by year whose values are the per-year spline objects.
splines = avg_df.groupby('year').apply(
    lambda group: spline3(group, 'month', 'mid'))

# Plot the raw year-5 observations together with their fitted spline.

# Select the year-5 rows once instead of re-grouping df for each column.
year5 = df[df['year'] == 5]
x, y = year5['month'], year5['mid']

# Evaluate the spline on a dense grid: sampling it only at the observed
# months would draw straight segments between the knots rather than the
# fitted curve.
x_fit = np.linspace(x.min(), x.max(), 200)

fig, ax = plt.subplots(figsize=(6.5, 4))
ax.plot(x, y, 'o', label='data')
ax.plot(x_fit, splines[5](x_fit), label="S")
ax.legend()  # the labels above are only shown once a legend is drawn

# Faceted scatter plot of the raw data, one panel per year.
g = sns.relplot(x='month', y='mid', hue='source', data=df,
                palette=['green', 'orange', 'purple'],
                col='year', col_wrap=5, legend=True, height=6, aspect=.5,
                style='source')

# Overlay each facet with the spline fitted for the year it displays.
for ax, year in zip(g.axes.flat, g.col_names):
    months = df.loc[df['year'] == year, 'month']
    # Look the spline up by year label rather than by position, so the
    # pairing stays correct whatever order the facets are drawn in.
    spline = splines[year]
    # Dense grid so the spline is drawn as a smooth curve, not as straight
    # segments between the observed months.
    x_fit = np.linspace(months.min(), months.max(), 200)
    ax.plot(x_fit, spline(x_fit))
    ax.xaxis.set_major_locator(MultipleLocator(3))
plt.show()