为一批数组计算数组矩阵乘法的有效方法
Efficient way to compute an array matrix multiplication for a batch of arrays
我想并行处理以下问题。给定一个形状为 (dim1,)
的数组 w
和一个形状为 (dim1, dim2)
的矩阵 A
,我希望 A
的每一行都乘以 A
的相应元素=11=]。
这很微不足道。
但是,我想对一堆数组执行此操作 w
,最后对结果求和。因此,为了避免 for 循环,我创建了形状为 (n_samples, dim1)
的矩阵 W
,并按以下方式使用了 np.einsum
函数:
x = np.einsum('ji, ik -> jik', W, A))
r = x.sum(axis=0)
其中 x
的形状为 (n_samples, dim1, dim2)
,最终总和的形状为 (dim1, dim2)
。
我注意到 np.einsum
对于大矩阵 A
来说相当慢。有没有更有效的方法来解决这个问题?我也想尝试 np.tensordot
但也许情况并非如此。
谢谢 :-)
In [455]: W = np.arange(1,7).reshape(2,3); A = np.arange(1,13).reshape(3,4)
你的计算:
In [463]: x = np.einsum('ji, ik -> jik', W, A)
...: r = x.sum(axis=0)
In [464]: r
Out[464]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
如评论中所述,einsum
可以对 j
执行求和:
In [465]: np.einsum('ji, ik -> ik', W, A)
Out[465]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
并且由于 j
只出现在 A
中,我们可以先对 A
求和:
In [466]: np.sum(W,axis=0)[:,None]*A
Out[466]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
这不涉及乘积和,因此也不涉及矩阵乘法。
或者乘法求和:
In [475]: (W[:,:,None]*A).sum(axis=0)
Out[475]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
我想并行处理以下问题。给定一个形状为 (dim1,)
的数组 w
和一个形状为 (dim1, dim2)
的矩阵 A
,我希望 A
的每一行都乘以 A
的相应元素=11=]。
这很微不足道。
但是,我想对一堆数组执行此操作 w
,最后对结果求和。因此,为了避免 for 循环,我创建了形状为 (n_samples, dim1)
的矩阵 W
,并按以下方式使用了 np.einsum
函数:
x = np.einsum('ji, ik -> jik', W, A))
r = x.sum(axis=0)
其中 x
的形状为 (n_samples, dim1, dim2)
,最终总和的形状为 (dim1, dim2)
。
我注意到 np.einsum
对于大矩阵 A
来说相当慢。有没有更有效的方法来解决这个问题?我也想尝试 np.tensordot
但也许情况并非如此。
谢谢 :-)
In [455]: W = np.arange(1,7).reshape(2,3); A = np.arange(1,13).reshape(3,4)
你的计算:
In [463]: x = np.einsum('ji, ik -> jik', W, A)
...: r = x.sum(axis=0)
In [464]: r
Out[464]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
如评论中所述,einsum
可以对 j
执行求和:
In [465]: np.einsum('ji, ik -> ik', W, A)
Out[465]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
并且由于 j
只出现在 A
中,我们可以先对 A
求和:
In [466]: np.sum(W,axis=0)[:,None]*A
Out[466]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])
这不涉及乘积和,因此也不涉及矩阵乘法。
或者乘法求和:
In [475]: (W[:,:,None]*A).sum(axis=0)
Out[475]:
array([[ 5, 10, 15, 20],
[ 35, 42, 49, 56],
[ 81, 90, 99, 108]])