Python 子图用于显示一张图

Python subplot used to show one figure

当我使用子图时,尝试使用下面的子图只绘制一张图,会报错:

AttributeError: 'AxesSubplot' object has no attribute 'flat'

fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')

任意设置子图个数如何解决?如何简单理解ax.flat

flatnumpy 数组的属性,returns 是迭代器。例如,如果你有一个像这样的二维数组:

import numpy as np
arr2d = np.arange(4).reshape(2, 2)
arr2d
# array([[0, 1],
#        [2, 3]])

提供 flat 属性作为一种方便的方式来遍历此数组,就好像它是一维数组一样:

for value in arr2d.flat:
    print(value)
# 0
# 1
# 2
# 3

你也可以用flatten方法展平数组:

arr2d.flatten()
# array([0, 1, 2, 3])

所以回到你的问题,当你指定:

  • ncols 到 1 和 nrows 到大于 1 的值或相反,你得到一维 numpy 数组中的轴,在这种情况下 flat 属性returns同一个数组。
  • ncolsnrows 都大于 1 的值,您将获得二维数组中的轴,在这种情况下 flat 属性 returns 扁平数组.
  • ncolsnrows 都为 1,您得到轴对象,它没有 flat 属性。

所以一个可能的解决方案是每次都将您的 ax 对象变成一个 numpy 数组:

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

ax = np.array(ax)
for i, axi in enumerate(ax.flat):
    axi.plot(...)

当您使用命令 fig,ax=plt.subplots() 创建一组具有多个 rows/columns 的子图时,它 returns 一个 fig 和一个轴列表 ax . ax 列表的形状是二维的(行,列)。这就是为什么在迭代 ax 列表时需要将其展平为一维的原因。要访问特定轴,您需要 row/column 索引,例如,ax[r][c] 是第 (r+1) 行/第 (c+1) 列上的轴。这些指数是从零开始的。下面的工作代码演示了如何做到这一点。

import matplotlib.pyplot as plt
import numpy as np

nrows,ncols = 3,2
figsize = [5,9]
X = np.random.rand(6)
Y = np.random.rand(6)

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
for i, axi in enumerate(ax.flat):
    axi.plot(X, Y, color='k')
    rowid = i // ncols
    colid = i % ncols
    axi.set_title("row:"+str(rowid)+",col:"+str(colid))

# You can access the axes by row_id, col_id.
# Now let's plot on ax[row_id][col_id] of your choice
ax[0][1].plot(Y,X,color='red')    # plot 2nd line in red
ax[2][0].plot(Y,X,color='green')  # plot 2nd line in green

plt.tight_layout(True)
plt.show()

输出图:

您可以使用 ax.ravel()ax.flatten()。下面是一个简单的例子

import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))

# for i, ax in enumerate(ax.flatten()):
for i, ax in enumerate(ax.ravel()):
    ax.plot([1,2,3], color='k')
plt.show()

恰好有一种情况代码

fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')

无法按预期工作。这是 nrows = ncols = 1。这是因为对于单个行和列,ax 是单个子图,而不是多个子图的数组。

为了避免这个问题,并且能够在事先不知道 nrowsncols 的情况下使用相同的代码,请使用 squeeze=False 选项。这将确保 ax 始终是一个数组,因此具有 .flat 属性。为了更好的理解,不要使用与轴本身相同的名称来调用轴数组。

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False, figsize=figsize)
for i, ax in enumerate(axs.flat):
    ax.plot(X, Y, color='k')