如何绘制具有多个轴的 pandas DataFrame,每个轴呈现多个列?

How to plot a pandas DataFrame with multiple axes each rendering multiple columns?

我正在寻找一个可以按如下方式工作的函数:

import pandas as pd

def plot_df(df: pd.DataFrame, x_column: str, columns: List[List[str]]):
  """Plot DataFrame using `x_column` on the x-axis and `len(columns)` different
  y-axes where the axis numbered `i` is calibrated to render the columns in `columns[i]`.

  Important: only 1 legend exists for the plot
  Important: each column has a distinct color
    If you wonder what colors axes should have, they can assume one of the line colors and just have a label associated (e.g., one axis for price, another for returns, another for growth)
"""

例如,对于包含列 time, price1, price2, returns, growth 的 DataFrame,您可以这样称呼它:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

这将导致图表具有:

我已经查看了一些不适用于此的解决方案。

可能的解决方案#1:

https://matplotlib.org/stable/gallery/ticks_and_spines/multiple_yaxis_with_spines.html

在这个例子中,每个轴只能容纳一列,所以是错误的。特别是在下面的代码中,每个轴都有一个系列:

p1, = ax.plot([0, 1, 2], [0, 1, 2], "b-", label="Density")
p2, = twin1.plot([0, 1, 2], [0, 3, 2], "r-", label="Temperature")
p3, = twin2.plot([0, 1, 2], [50, 30, 15], "g-", label="Velocity")

如果您向该轴添加另一个图,相同的颜色最终会重复:

此外,此版本不使用数据框的内置plot()功能。

可能的解决方案#2:

PANDAS plot multiple Y axes

在此示例中,每个轴也只能容纳数据框中的一列。

可能的解决方案#3:

尝试通过将 df.A 更改为 df[['A', 'B']] 来调整解决方案 2,但这并不能奏效,因为它会导致这两列共享相同的轴颜色以及弹出多个图例.

所以 - 询问 pandas/matplotlib 专家是否可以想出如何克服这个问题!

我假设您正在使用这样的数据框:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

那么一个可能的函数可能是:

def plot_df(df, x_column, columns):

    cmap = cm.get_cmap('tab10', 10)
    line_styles = ["-", "--", "-.", ":"]

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    for i, _ in enumerate(range(len(columns) - 1)):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    for i, col in enumerate(columns):
        if len(col) == 1:
            p, = axes[i].plot(df[x_column], df[col[0]], label = col[0], color = cmap(i)[:3])
            handles.append(p)
        else:
            for j, sub_col in enumerate(col):
                p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(i)[:3], linestyle = line_styles[j])
                handles.append(p)

    ax.legend(handles = handles, frameon = True)

    for i, ax in enumerate(axes):
        ax.tick_params(axis = 'y', colors = cmap(i)[:3])
        if i == 0:
            ax.spines['left'].set_color(cmap(i)[:3])
            ax.spines['right'].set_visible(False)
        else:
            ax.spines['left'].set_visible(False)
            ax.spines['right'].set_color(cmap(i)[:3])

    plt.tight_layout()

    plt.show()

如果你调用上面的函数:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

那么你将得到:

注释

  1. 由于 price1price2 共享相同的 y 轴,它们必须共享相同的颜色,所以我必须使用不同的 linestyle 才能区分它们。
  2. columns 列表的第一个元素(在本例中为 ['price1', 'price2'])始终绘制在左轴上,其他元素绘制在右轴上。
  3. 如果您想设置轴限制和标签,那么您应该将这些作为附加参数传递给 plot_df

我假设您正在使用这样的数据框:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

那么一个可能的函数可能是:

def plot_df(df, x_column, columns):

    cmap = cm.get_cmap('tab10', 10)

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    for i, _ in enumerate(range(len(columns) - 1)):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    j = 0
    for i, col in enumerate(columns):
        ylabel = []
        if len(col) == 1:
            p, = axes[i].plot(df[x_column], df[col[0]], label = col[0], color = cmap(j)[:3])
            ylabel.append(col[0])
            handles.append(p)
            j += 1
        else:

            for sub_col in col:
                p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(j)[:3])
                ylabel.append(sub_col)
                handles.append(p)
                j += 1
        axes[i].set_ylabel(', '.join(ylabel))

    ax.legend(handles = handles, frameon = True)

    plt.tight_layout()

    plt.show()

如果你调用上面的函数:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

那么你将得到:

注释

列列表的第一个元素(在本例中为 ['price1', 'price2'])始终绘制在左轴上,其他元素绘制在右轴上。

您可以将轴从 df 链接到 df。

import pandas as pd
import numpy as np

创建数据并将其放入 df。

x=np.arange(0,2*np.pi,0.01)
b=np.sin(x)
c=np.cos(x)*10
d=np.sin(x+np.pi/4)*100
e=np.sin(x+np.pi/3)*50
df = pd.DataFrame({'x':x,'y1':b,'y2':c,'y3':d,'y4':e})

定义第一个图和后续轴

ax1 = df.plot(x='x',y='y1',legend=None,color='black',figsize=(10,8))
ax2 = ax1.twinx()
ax2.tick_params(axis='y', labelcolor='r')

ax3 = ax1.twinx()
ax3.spines['right'].set_position(('axes',1.15))
ax3.tick_params(axis='y', labelcolor='g')

ax4=ax1.twinx()
ax4.spines['right'].set_position(('axes',1.30))
ax4.tick_params(axis='y', labelcolor='b')

想加多少就加多少...

绘制余数。

df.plot(x='x',y='y2',ax=ax2,color='r',legend=None)
df.plot(x='x',y='y3',ax=ax3,color='g',legend=None)
df.plot(x='x',y='y4',ax=ax4,color='b',legend=None)

Results: