嵌套或组合 matplotlib 图形和绘图?
Nesting or combining matplotlib figures and plots?
我有一个函数,它采用任意长度的日期、价格(浮动)和一些结果值(浮动)的 3D 数据集,并制作一组按年份划分的 seaborn 热图。伪代码如下(注意年数因数据集而异,所以我需要它任意缩放):
def makePlots(data):
split data by year
fig,axs=plt.subplots(1, numYears)
x=0
for year in years
sns.heatmap(data[year], ax = axs[x++])
return axs
这会输出一个单独的 matplotlib 图,其中每一年的热图在一行中彼此相邻,如本例所示:single plotted dataset
现在我有一个更高级别的函数,我在其中输入两个数据集(每个数据集任意年数)并让它打印每个数据集的热图以供比较。我希望它以某种方式获取 makePlots
方法制作的数字并将它们堆叠在一起,如本例所示:two plotted datasets
def compareData(data1,data2):
fig1 = makePlots(data1)
fig2 = makePlots(data2)
fig, (ax1,ax2) = plt.subplots(2,1)
ax1 = fig1
ax2 = fig2
plt.show()
现在这段代码 可以工作 ,但不是预期的那样。它打开了 3 个新图 windows,一个数据 1 正确绘制,一个数据 2 正确绘制,一个空 2 行子图。有什么方法可以将 makePlots 图嵌套在一个新的子图中,一个在另一个之上?我也试过返回 plt.gcf()
。堆栈溢出的所有其他答案都取决于将轴传递给 plot 方法,但考虑到我每个数据集有任意数量的轴(年)并且最终想比较任意数量的数据集,这似乎并不理想(不是那个反正我可以想出一个实现,因为每一行都可以有任意年数)。
我不推荐它,但您可以使用 fig.add_subplot(nrow, ncol, index)
逐步添加子图。
所以你的两个函数看起来像这样:
def compareData(data1, data2):
fig = plt.figure()
makePlots(data1, row=0, fig=fig)
makePlots(data2, row=1, fig=fig)
def makePlots(data, row, fig):
years = ... # parse data here
for ii, year in enumerate(years):
ax = fig.add_subplot(2, len(years), row * len(years) + ii + 1)
sns.heatmap(data[year], ax=ax)
这有望解决您的问题。
但是,您遇到这个问题只是因为您在同一个函数中混合了数据解析和绘图。我的建议是首先解析数据,然后将新的数据结构传递给一些绘图函数。
我有一个函数,它采用任意长度的日期、价格(浮动)和一些结果值(浮动)的 3D 数据集,并制作一组按年份划分的 seaborn 热图。伪代码如下(注意年数因数据集而异,所以我需要它任意缩放):
def makePlots(data):
split data by year
fig,axs=plt.subplots(1, numYears)
x=0
for year in years
sns.heatmap(data[year], ax = axs[x++])
return axs
这会输出一个单独的 matplotlib 图,其中每一年的热图在一行中彼此相邻,如本例所示:single plotted dataset
现在我有一个更高级别的函数,我在其中输入两个数据集(每个数据集任意年数)并让它打印每个数据集的热图以供比较。我希望它以某种方式获取 makePlots
方法制作的数字并将它们堆叠在一起,如本例所示:two plotted datasets
def compareData(data1,data2):
fig1 = makePlots(data1)
fig2 = makePlots(data2)
fig, (ax1,ax2) = plt.subplots(2,1)
ax1 = fig1
ax2 = fig2
plt.show()
现在这段代码 可以工作 ,但不是预期的那样。它打开了 3 个新图 windows,一个数据 1 正确绘制,一个数据 2 正确绘制,一个空 2 行子图。有什么方法可以将 makePlots 图嵌套在一个新的子图中,一个在另一个之上?我也试过返回 plt.gcf()
。堆栈溢出的所有其他答案都取决于将轴传递给 plot 方法,但考虑到我每个数据集有任意数量的轴(年)并且最终想比较任意数量的数据集,这似乎并不理想(不是那个反正我可以想出一个实现,因为每一行都可以有任意年数)。
我不推荐它,但您可以使用 fig.add_subplot(nrow, ncol, index)
逐步添加子图。
所以你的两个函数看起来像这样:
def compareData(data1, data2):
fig = plt.figure()
makePlots(data1, row=0, fig=fig)
makePlots(data2, row=1, fig=fig)
def makePlots(data, row, fig):
years = ... # parse data here
for ii, year in enumerate(years):
ax = fig.add_subplot(2, len(years), row * len(years) + ii + 1)
sns.heatmap(data[year], ax=ax)
这有望解决您的问题。
但是,您遇到这个问题只是因为您在同一个函数中混合了数据解析和绘图。我的建议是首先解析数据,然后将新的数据结构传递给一些绘图函数。