使用 matplotlib 在网格上绘制数据框中的散点图

Plotting scatterplots from a dataframe on a grid with matplotlib

有没有办法用来自数据框的所有列的散点图制作网格,其中 Y 是数据框的列之一?

我可以为此在 matplotlibseaborn 上执行 for 循环(请参阅下面的代码),但我无法让它们显示在网格上。

我希望将它们显示在网格可视化中,以便于比较它们。

这是我能做的:

for col in boston_df:
    plt.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

for col in boston_df:
    sns.regplot(x=boston_df[col], y=boston_df["MEDV"])
    plt.show()

现在,如果我尝试创建一个子图,并像这样在我的循环中使用 ax.scatter()

fig, ax = plt.subplots(3, 5,figsize=(16,6))
for col in boston_df:
    ax.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

它给我错误 AttributeError: 'numpy.ndarray' object has no attribute 'scatter'

如果能找到像这样简单的解决方案就好了:

df.hist(figsize=(18,10), density=True, label=df.columns)
plt.show()

考虑使用 pandas 的 ax 参数 DataFrame.plot and seaborn's regplot:

fig, ax = plt.subplots(1, 5, figsize=(16,6))

for i,col in enumerate(boston_df.columns[1:]):
     #boston_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
     sns.regplot(x=boston_df[col], y=boston_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)      # TO ACCOMMODATE TITLE

plt.show()

用随机数据进行演示:

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,6), 
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4', 'COL5'])

### PLOT BUILD
fig, ax = plt.subplots(1, 5, figsize=(16,6))

for i,col in enumerate(random_df.columns[1:]):
     #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
     sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()

对于跨多列的多行,将分配调整为 ax 这是一个使用索引的 numpy 数组:ax[row_idx, col_idx].

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,14), 
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4', 
                                    'COL5', 'COL6', 'COL7', 'COL8', 'COl9', 
                                    'COL10', 'COL11', 'COL12', 'COL13'])

### PLOT BUILD
fig, ax = plt.subplots(2, 7, figsize=(16,6))

for i,col in enumerate(random_df.columns[1:]):
     #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
     if i <= 6:
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[0,i])
     else:
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[1,i-7])     

ax[1,6].axis('off')                  # HIDES AXES ON LAST ROW AND COL

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()