Seaborn:将多个 jointplot 放入子图中不起作用
Seaborn: Subplot of jointplots doesn't work
我想创建一个图,把两个不同的联合分布图(jointplot)水平并排放置,因此我使用了 `plt.subplots(1, 2)`。
但是,结果有两个问题:
- 不知何故,顶部多出了两个空白的子图,我想把它们去掉。
- 这些图目前是垂直排列的,而不是水平排列的。
如何修改我的代码来修复它?提前致谢!
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt  # was missing: plt.subplots is called below

sns.set(style="darkgrid")
iris = sns.load_dataset("iris")

# NOTE: this is the question's original attempt. sns.jointplot builds a
# figure of its own (seaborn documents it as a figure-level function), which
# is presumably why passing ax=axes[i] does not place it inside axes[i]: the
# two axes created by plt.subplots(1, 2) stay blank and each jointplot ends
# up in a separate figure, as the question describes.
fig, axes = plt.subplots(1, 2)

# First plot: draw a regression jointplot, then clear its joint axes and
# redraw them as a scatter plot whose marker size encodes petal_length.
g = sns.jointplot(ax = axes[0], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_length', sizes=(10, 200), ax=g.ax_joint)

# Second plot: same construction, with marker size encoding petal_width.
g = sns.jointplot(ax = axes[1], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_width', sizes=(10, 200), ax=g.ax_joint)
根据你的问题,我能够先用联合图(JointGrid)分别创建每张图。随后我围绕如何把它们放进子图做了很多研究,并在这里找到了一个很有启发性的 answer,于是将其应用到你的例子中。它很好地解决了你的问题。感谢 @ImportanceOfBeingErnest!
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
class SeabornFig2Grid():
    """Re-home the axes of a seaborn grid into one cell of another figure.

    seaborn's grid objects (``FacetGrid``, ``PairGrid``, ``JointGrid``) draw
    into a figure of their own. This helper removes their axes from that
    figure and adds them to ``fig``, laid out inside ``subplot_spec``. It then
    closes the seaborn figure and keeps that figure's size in sync with
    ``fig`` on resize events.

    Args:
        seaborngrid: a seaborn ``FacetGrid`` or ``PairGrid`` (axes moved
            keeping their n x m layout) or a ``JointGrid`` (joint and marginal
            axes moved). For any other object no axes are moved, but
            ``_finalize`` still runs and reads ``seaborngrid.fig``.
        fig: destination matplotlib figure.
        subplot_spec: the cell of ``fig`` that receives the axes, e.g.
            ``gridspec.GridSpec(1, 2)[0]``.
    """

    def __init__(self, seaborngrid, fig, subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, (sns.axisgrid.FacetGrid, sns.axisgrid.PairGrid)):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()

    def _movegrid(self):
        """Move PairGrid or FacetGrid axes, keeping their n x m arrangement."""
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n, m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j])

    @staticmethod
    def _grid_ratio(joint_height, marginal_height):
        """Return joint_height / marginal_height rounded to the nearest int.

        Uses the built-in ``round`` (round-half-to-even, like ``numpy.round``)
        so this class does not rely on ``numpy`` having been imported as
        ``np`` somewhere else.
        """
        return int(round(joint_height / marginal_height))

    def _movejointgrid(self):
        """Move the joint and both marginal axes of a JointGrid.

        The destination cell is split into an (r + 1) x (r + 1) sub-grid,
        where r is the joint/marginal height ratio measured on the seaborn
        figure before resizing. Placement within the sub-grid:

        - ``ax_marg_x`` takes the top row, minus the last column.
        - ``ax_marg_y`` takes the last column, minus the top row.
        - ``ax_joint`` takes the remaining r x r block.
        """
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        r = self._grid_ratio(h, h2)
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=self.subplot)
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Detach ``ax`` from its figure and place it at ``gs`` in ``self.fig``."""
        ax.remove()
        # Re-point the axes at the destination figure before add_axes();
        # presumably needed so add_axes accepts an Axes created in another
        # figure -- verify against the installed matplotlib version.
        ax.figure = self.fig
        # NOTE(review): recent matplotlib returns a copy from Figure.axes, so
        # this append is likely a no-op; kept for older versions.
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the seaborn grid's own figure, track resizes of ``self.fig``, redraw."""
        plt.close(self.sg.fig)
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Give the seaborn figure the same size as ``self.fig`` (``evt`` is unused)."""
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
sns.set(style="darkgrid")
iris = sns.load_dataset("iris")

# Build one JointGrid per size variable, in order (petal_length first).
# Each grid gets a scatter plot whose marker size encodes that column, plus
# histogram + KDE marginals.
joint_grids = []
for size_column in ('petal_length', 'petal_width'):
    grid = sns.JointGrid(x="sepal_width", y="sepal_length", data=iris)
    grid.plot_joint(sns.scatterplot, sizes=(10, 200), size=iris[size_column], legend='brief')
    grid.plot_marginals(sns.histplot, kde=True, color='k')
    joint_grids.append(grid)
g0, g1 = joint_grids

# Host figure with two side-by-side cells; each JointGrid's axes are moved
# into one of them.
fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(1, 2)
# Keep mg0/mg1 bound. matplotlib documents that its callback registry holds
# bound methods weakly, so dropping these names would presumably also drop
# their resize handlers.
mg0 = SeabornFig2Grid(g0, fig, gs[0])
mg1 = SeabornFig2Grid(g1, fig, gs[1])
gs.tight_layout(fig)
plt.show()
我想创建一个图,把两个不同的联合分布图(jointplot)水平并排放置,因此我使用了 `plt.subplots(1, 2)`。
但是,结果有两个问题:
- 不知何故,顶部多出了两个空白的子图,我想把它们去掉。
- 这些图目前是垂直排列的,而不是水平排列的。
如何修改我的代码来修复它?提前致谢!
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt  # was missing: plt.subplots is called below

sns.set(style="darkgrid")
iris = sns.load_dataset("iris")

# NOTE: this is the question's original attempt. sns.jointplot builds a
# figure of its own (seaborn documents it as a figure-level function), which
# is presumably why passing ax=axes[i] does not place it inside axes[i]: the
# two axes created by plt.subplots(1, 2) stay blank and each jointplot ends
# up in a separate figure, as the question describes.
fig, axes = plt.subplots(1, 2)

# First plot: draw a regression jointplot, then clear its joint axes and
# redraw them as a scatter plot whose marker size encodes petal_length.
g = sns.jointplot(ax = axes[0], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_length', sizes=(10, 200), ax=g.ax_joint)

# Second plot: same construction, with marker size encoding petal_width.
g = sns.jointplot(ax = axes[1], x="sepal_width", y="sepal_length", data=iris, kind="reg", color='k')
g.ax_joint.cla()
sns.scatterplot(data=iris, x='sepal_width', y='sepal_length', size='petal_width', sizes=(10, 200), ax=g.ax_joint)
根据你的问题,我能够先用联合图(JointGrid)分别创建每张图。随后我围绕如何把它们放进子图做了很多研究,并在这里找到了一个很有启发性的 answer,于是将其应用到你的例子中。它很好地解决了你的问题。感谢 @ImportanceOfBeingErnest!
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
class SeabornFig2Grid():
    """Re-home the axes of a seaborn grid into one cell of another figure.

    seaborn's grid objects (``FacetGrid``, ``PairGrid``, ``JointGrid``) draw
    into a figure of their own. This helper removes their axes from that
    figure and adds them to ``fig``, laid out inside ``subplot_spec``. It then
    closes the seaborn figure and keeps that figure's size in sync with
    ``fig`` on resize events.

    Args:
        seaborngrid: a seaborn ``FacetGrid`` or ``PairGrid`` (axes moved
            keeping their n x m layout) or a ``JointGrid`` (joint and marginal
            axes moved). For any other object no axes are moved, but
            ``_finalize`` still runs and reads ``seaborngrid.fig``.
        fig: destination matplotlib figure.
        subplot_spec: the cell of ``fig`` that receives the axes, e.g.
            ``gridspec.GridSpec(1, 2)[0]``.
    """

    def __init__(self, seaborngrid, fig, subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, (sns.axisgrid.FacetGrid, sns.axisgrid.PairGrid)):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()

    def _movegrid(self):
        """Move PairGrid or FacetGrid axes, keeping their n x m arrangement."""
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n, m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j])

    @staticmethod
    def _grid_ratio(joint_height, marginal_height):
        """Return joint_height / marginal_height rounded to the nearest int.

        Uses the built-in ``round`` (round-half-to-even, like ``numpy.round``)
        so this class does not rely on ``numpy`` having been imported as
        ``np`` somewhere else.
        """
        return int(round(joint_height / marginal_height))

    def _movejointgrid(self):
        """Move the joint and both marginal axes of a JointGrid.

        The destination cell is split into an (r + 1) x (r + 1) sub-grid,
        where r is the joint/marginal height ratio measured on the seaborn
        figure before resizing. Placement within the sub-grid:

        - ``ax_marg_x`` takes the top row, minus the last column.
        - ``ax_marg_y`` takes the last column, minus the top row.
        - ``ax_joint`` takes the remaining r x r block.
        """
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        r = self._grid_ratio(h, h2)
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=self.subplot)
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Detach ``ax`` from its figure and place it at ``gs`` in ``self.fig``."""
        ax.remove()
        # Re-point the axes at the destination figure before add_axes();
        # presumably needed so add_axes accepts an Axes created in another
        # figure -- verify against the installed matplotlib version.
        ax.figure = self.fig
        # NOTE(review): recent matplotlib returns a copy from Figure.axes, so
        # this append is likely a no-op; kept for older versions.
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the seaborn grid's own figure, track resizes of ``self.fig``, redraw."""
        plt.close(self.sg.fig)
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Give the seaborn figure the same size as ``self.fig`` (``evt`` is unused)."""
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
sns.set(style="darkgrid")
iris = sns.load_dataset("iris")

# Build one JointGrid per size variable, in order (petal_length first).
# Each grid gets a scatter plot whose marker size encodes that column, plus
# histogram + KDE marginals.
joint_grids = []
for size_column in ('petal_length', 'petal_width'):
    grid = sns.JointGrid(x="sepal_width", y="sepal_length", data=iris)
    grid.plot_joint(sns.scatterplot, sizes=(10, 200), size=iris[size_column], legend='brief')
    grid.plot_marginals(sns.histplot, kde=True, color='k')
    joint_grids.append(grid)
g0, g1 = joint_grids

# Host figure with two side-by-side cells; each JointGrid's axes are moved
# into one of them.
fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(1, 2)
# Keep mg0/mg1 bound. matplotlib documents that its callback registry holds
# bound methods weakly, so dropping these names would presumably also drop
# their resize handlers.
mg0 = SeabornFig2Grid(g0, fig, gs[0])
mg1 = SeabornFig2Grid(g1, fig, gs[1])
gs.tight_layout(fig)
plt.show()