将 matplotlib 3D 图形的框架更改为 x、y 和 z 箭头

Change a matplotlib 3D figure's frames into x,y and z arrows

是否可以把图形的边框(坐标轴)改成箭头?例如,把箭头叠加在 x、y、z 轴之上,造成"坐标轴本身就是箭头"的视觉效果;或者直接修改 Matplotlib 的坐标轴框架设置,从而在 3D 图上得到同样的结果,即用箭头来表示 (x, y, z) 三个轴?

也就是把下面这段代码生成的图:

import numpy as np
import matplotlib.pyplot as plt

# Create a figure holding a single 3D axes.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# generate sample points and straight line (x = 1, z = 0, y from 3 to 6)
z = np.repeat(0, 100)
x = np.repeat(1.0, 100)
y = np.linspace(start=3.0, stop=6.0, num=100)
ax.plot(x, y, z, c='red')  # draw straight line
ax.view_init(45, -150)  # elevation and azimuth of the viewpoint

# set axes limits and labels
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_zlabel(r"$z$")
ax.set_xlim(0, 1.1)
ax.set_ylim(6, 3)  # limits given high-to-low, so the y axis is inverted
ax.set_zlim(0, 1.75)

# Custom tick labels on x; blank tick labels on y and z.
# (The original crammed two statements onto one line without a separator,
# which is a SyntaxError; one call per line avoids that.)
ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
ax.set_xticklabels(['0', '1', '2', '4', 'T'])
ax.set_yticks([6.0, 5.5, 5, 4.5, 4.0, 3.5, 3])
ax.set_yticklabels(["", "", "", "", "", "", ""])
ax.set_zticks([1.75, 1.25, 0.75, 0.25])
ax.set_zticklabels(['', '', '', ''])

# Paint the x-axis pane white. `ax.w_xaxis` was removed in Matplotlib 3.8;
# `ax.xaxis` is the supported spelling of the same axis object.
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Optional: save the image (call before plt.show()).
# plt.savefig("sample.png", format="png", dpi=400)
plt.tight_layout()
plt.show()

变成这样的东西:

我平时不常用 3D 图形,为了回答你的问题查了不少资料,找到了一个很好的方法:创建一个新的 Arrow3D 类并加以实现。我在你的代码中加入了这个类,并为 x、y、z 轴分别添加了箭头;箭头的位置是我手动调整的,以便与坐标轴对齐。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.proj3d import proj_transform
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    """A `FancyArrowPatch` whose end points are given in 3D data coordinates.

    The arrow starts at ``(x, y, z)`` and ends at ``(x + dx, y + dy, z + dz)``.
    Its 2D end points are recomputed from the owning axes' projection matrix
    each time the artist is projected or drawn, so it follows view changes.
    """

    def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
        """Store the 3D anchor and displacement.

        Remaining positional and keyword arguments are forwarded unchanged to
        `FancyArrowPatch` (e.g. ``arrowstyle``, ``mutation_scale``, ``fc``).
        """
        # Placeholder 2D positions; the real ones are set on projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._xyz = (x, y, z)
        self._dxdydz = (dx, dy, dz)

    def _project(self):
        """Update the 2D end points from the 3D ones; return projected depths."""
        x1, y1, z1 = self._xyz
        dx, dy, dz = self._dxdydz
        x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)

        xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Project the arrow and return its depth for draw ordering.

        Matplotlib >= 3.5 calls this on artists placed in a 3D axes; without
        it, drawing fails with an AttributeError. ``renderer`` is optional
        because the hook has been called both with and without it across
        Matplotlib versions.
        """
        return min(self._project())

    def draw(self, renderer):
        """Re-project the end points, then draw as a regular 2D arrow."""
        self._project()
        super().draw(renderer)

def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    """Add a 3D arrow to an `Axes3D` instance and return it.

    The arrow runs from ``(x, y, z)`` to ``(x + dx, y + dy, z + dz)``; any
    extra arguments are passed through to `Arrow3D`. The created artist is
    returned so callers can restyle or remove it later (callers that ignore
    the return value are unaffected).
    """
    arrow = Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs)
    ax.add_artist(arrow)
    return arrow


# Expose the helper as a method, so it can be called as ``ax.arrow3D(...)``.
Axes3D.arrow3D = _arrow3D

# Create a square figure holding a single 3D axes.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# generate sample points and straight line (x = 1, z = 0, y from 3 to 6)
z = np.repeat(0, 100)
x = np.repeat(1.0, 100)
y = np.linspace(start=3.0, stop=6.0, num=100)
ax.plot(x, y, z, c='red')  # draw straight line
ax.view_init(45, -150)  # elevation and azimuth of the viewpoint

# set axes limits and labels
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_zlabel(r"$z$")
ax.set_xlim(0, 1.1)
ax.set_ylim(6, 3)  # limits given high-to-low, so the y axis is inverted
ax.set_zlim(0, 1.75)

# Custom tick labels on x; blank tick labels on y and z.
ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
ax.set_xticklabels(['0', '1', '2', '4', 'T'])
ax.set_yticks([6.0, 5.5, 5, 4.5, 4.0, 3.5, 3])
ax.set_yticklabels(["", "", "", "", "", "", ""])
ax.set_zticks([1.75, 1.25, 0.75, 0.25])
ax.set_zticklabels(['', '', '', ''])

# Paint the x-axis pane white. `ax.w_xaxis` was removed in Matplotlib 3.8;
# `ax.xaxis` is the supported spelling of the same axis object.
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Read the limits back from the axes created above.
# With the settings above these are (0.0, 1.1), (6.0, 3.0), (0.0, 1.75).
xlim = ax.get_xlim()
ylim = ax.get_ylim()
zlim = ax.get_zlim()

# Double-headed arrows laid over the three axes. The small offsets
# (-0.03, +0.06, ...) are hand-tuned so the arrows line up with the axes
# for this particular view and these limits.
ax.arrow3D(-0.03, ylim[0]+0.06, 0, xlim[1]+0.05, 0, 0, mutation_scale=20, arrowstyle='<|-|>', fc='k')  # x axis
ax.arrow3D(-0.03, ylim[1], 0, 0, ylim[1]+0.1, 0, mutation_scale=20, arrowstyle='<|-|>', fc='k')  # y axis
ax.arrow3D(-0.05, ylim[1], 0, 0, 0, zlim[1]+0.1, mutation_scale=20, arrowstyle='<|-|>', fc='k')  # z axis

# Axis names, placed in axes-fraction coordinates (also hand-tuned).
ax.text2D(0.05, 0.65, r'$\mathcal{Z}$', fontsize=18, ha='center', transform=ax.transAxes)
ax.text2D(0.60, -0.03, r'$\mathcal{Y}$', fontsize=18, ha='center', transform=ax.transAxes)
ax.text2D(0.95, 0.40, r'$\mathcal{X}$', fontsize=18, ha='center', transform=ax.transAxes)

# Make the tick marks white.
ax.tick_params(axis='both', color='white')

# Optional: save the image (call before plt.show()).
# plt.savefig("sample.png", format="png", dpi=400)
# plt.tight_layout()
plt.show()

下面是另一种变通的实现方法:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# copied from your code
# Use plt.figure() (pyplot-managed) rather than plt.Figure(): the latter
# builds a detached Figure while plt.subplot() silently creates another one.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

z = np.repeat(0, 100)
x = np.repeat(1.0, 100)
y = np.linspace(start=3.0, stop=6.0, num=100)
ax.plot(x, y, z, c='red')  # draw straight line
ax.view_init(45, -150)  # elevation and azimuth of the viewpoint
ax.set_xlim(0, 1.1)
ax.set_ylim(6, 3)  # limits given high-to-low, so the y axis is inverted
ax.set_zlim(0, 1.75)

# my code starts here
# NOTE: these are the limits in the order they were set, so with the
# inverted y axis above `ymin` is 6 and `ymax` is 3.
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
zmin, zmax = ax.get_zlim()

# Thick arrows along the three axes (the y axis is drawn in both directions).
ax.quiver3D(xmin, ymin, zmin, (xmax-xmin), 0, 0, length=1, arrow_length_ratio=0.1, colors='k', linewidth=3)
ax.quiver3D(xmin, ymin, zmin, 0, (ymax-ymin), 0, length=1, arrow_length_ratio=0.1, colors='b', linewidth=3)
ax.quiver3D(xmin, ymax, zmin, 0, (ymin-ymax), 0, length=1, arrow_length_ratio=0.1, colors='b', linewidth=3)
ax.quiver3D(xmin, ymax, zmin, 0, 0, (zmax-zmin), length=1, arrow_length_ratio=0.1, colors='k', linewidth=3)

# Thin, semi-transparent, headless segments for the remaining box edges.
ax.quiver3D(xmax, ymin, zmin, 0, (ymax-ymin), 0, length=1, arrow_length_ratio=0, colors='k', linewidth=1, alpha=0.5)
ax.quiver3D(xmin, ymax, zmin, (xmax-xmin), 0, 0, length=1, arrow_length_ratio=0, colors='k', linewidth=1, alpha=0.5)
ax.quiver3D(xmax, ymax, zmin, 0, 0, (zmax-zmin), length=1, arrow_length_ratio=0, colors='k', linewidth=1, alpha=0.5)

# Re-apply the limits captured before the quivers were added.
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_zlim(zmin, zmax)

# One tick per axis, reused to carry the axis name as its label.
ax.set_xticks([xmax])
ax.set_yticks([ymin])
ax.set_zticks([zmax])

ax.set_xticklabels([r'$\mathcal{X}$'])
ax.set_yticklabels([r'$\mathcal{Y}$'])
ax.set_zticklabels([r'$\mathcal{Z}$'])
ax.grid(False)  # turn the pane grid off explicitly

# Make the built-in axis lines nearly invisible so only the arrows show.
# `ax.w_xaxis` / `w_yaxis` / `w_zaxis` were removed in Matplotlib 3.8;
# `ax.xaxis` / `yaxis` / `zaxis` are the same axis objects.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.line.set_linewidth(0.01)

# Hide the tick marks but keep the (black) labels, pulled in towards the axes.
for axis_name in ('x', 'y', 'z'):
    ax.tick_params(axis=axis_name, colors='w', pad=-5, labelcolor='k', tick1On=False, tick2On=False)

plt.show()

输出为:

我会尽量总结一下代码:

  • xyz标签可以使用标准ax.set_xlabel('x-axis')等单独设置。
  • 以上代码同样适用于子图,例如使用 ax = plt.subplot(121, projection='3d'),而不是固定的 111 位置。
  • y轴的颜色当然可以改成黑色。我特意放了蓝色,这样你可以更容易地在代码中找到它。

有什么不明白的可以留言,我会尽量解释的。