从矩阵 B 的每一行中减去矩阵 A 的每一行,无需循环

Subtract each row of matrix A from every row of matrix B without loops

给定两个数组,A(形状:M X C)和B(形状:N X C),有没有办法从[的每一行中减去A的每一行=13=] 不使用循环?最终输出的形状为 (M N X C)。


例子

A = np.array([[  1,   2,   3], 
              [100, 200, 300]])

B = np.array([[  10,   20,   30],
              [1000, 2000, 3000],
              [ -10,  -20,   -2]])

期望的结果(可以有一些其他形状)(已编辑):

array([[  -9,   -18,   -27],
       [-999, -1998, -2997],
       [  11,    22,     5],
       [  90,   180,   270],
       [-900, -1800, -2700],
       [ 110,   220,   302]])

Shape: 6 X 3

(循环太慢,“outer”减去每个元素而不是每行)

可以通过利用 broadcasting 来高效地完成它(不使用任何循环),例如:

In [28]: (A[:, np.newaxis] - B).reshape(-1, A.shape[1])
Out[28]: 
array([[   -9,   -18,   -27],
       [ -999, -1998, -2997],
       [   11,    22,     5],
       [   90,   180,   270],
       [ -900, -1800, -2700],
       [  110,   220,   302]])

或者,为了比 broadcasting 快一点的解决方案,我们必须使用 numexpr,例如:

In [31]: A_3D = A[:, np.newaxis]
In [32]: import numexpr as ne

# pass the expression for subtraction as a string to `evaluate` function
In [33]: ne.evaluate('A_3D - B').reshape(-1, A.shape[1])
Out[33]: 
array([[   -9,   -18,   -27],
       [ -999, -1998, -2997],
       [   11,    22,     5],
       [   90,   180,   270],
       [ -900, -1800, -2700],
       [  110,   220,   302]], dtype=int64)

另一种效率最低的方法是使用 np.repeat and np.tile 来匹配两个数组的形状。但是,请注意,这是 效率最低的 选项,因为它在尝试匹配形状时会生成 副本

In [27]: np.repeat(A, B.shape[0], 0) - np.tile(B, (A.shape[0], 1))
Out[27]: 
array([[   -9,   -18,   -27],
       [ -999, -1998, -2997],
       [   11,    22,     5],
       [   90,   180,   270],
       [ -900, -1800, -2700],
       [  110,   220,   302]])

使用 Kronecker product (numpy.kron):

>>> import numpy as np
>>> A = np.array([[  1,   2,   3], 
...               [100, 200, 300]])
>>> B = np.array([[  10,   20,   30],
...               [1000, 2000, 3000],
...               [ -10,  -20,   -2]])
>>> (m,c) = A.shape
>>> (n,c) = B.shape
>>> np.kron(A,np.ones((n,1))) - np.kron(np.ones((m,1)),B)
array([[   -9.,   -18.,   -27.],
       [ -999., -1998., -2997.],
       [   11.,    22.,     5.],
       [   90.,   180.,   270.],
       [ -900., -1800., -2700.],
       [  110.,   220.,   302.]])