PyTorch 中有矩阵左除的函数吗?

Is there a function in PyTorch for matrix left division?

MATLAB 具有反斜杠“\”运算符。 SciPy 有“lsqr”。 PyTorch 是否有求解线性方程组的等效运算符?

具体来说,我需要为 A 求解 A*X=B 的矩阵方程,并且我需要 autograd 才能通过运算反向传播错误。

Python 中没有 \ 运算符。最接近的是 Scipy 的实现:scipy.sparse.linalg.lsqr.

您可以使用

  • torch.solve 求解形状为 AX=B

    的线性方程
  • torch.lsrsq

    • 求解最小二乘问题min ||AX-B||_2(如果A.size(0) >= A.size(1)

    • 解决最小范数问题 min ||X||_2 使得 AX=B(如果 A.size(0) < A.size(1)


要求解 XA=B,您将使用转置矩阵:

def lsqrt(A, B):
    XT, _ = torch.solve(B.T, A.T)
    return XT.T