如何向量化散点矩阵运算

how to vectorize the scatter-matmul operation

我有很多矩阵 w1w2w3...wn,形状为 (k*n1k*n2k*n3...k*nn) 和 x1x2x3...xn 具有形状 (n1*mn2*mn3*m...nn*m)。
我想分别得到 w1@x1, w2@x2, w3@x3 ...

生成的矩阵是多个 k*m 矩阵,可以连接成一个形状为 (k*n)*m.

的大矩阵

一个一个地相乘会很慢。如何向量化这个操作?

注意:输入可以是一个k*(n1+n2+n3+...+nn)矩阵和一个(n1+n2+n3+...+nn)*m矩阵,我们可以使用批索引来表示这些子矩阵。

此操作与pytorch_scatter中实现的分散操作有关,因此我将其称为“scatter_matmul”。

您可以通过创建形状为 n*kx(n1+..+nn) 的大型块对角矩阵 W 来矢量化您的操作,其中 w_i 矩阵是对角线上的块.然后你可以将所有 x 矩阵垂直堆叠成形状为 (n1+..+nn)xmX 矩阵。将块对角线 W 与所有 x 矩阵的垂直堆栈相乘,X

Y = W @ X

Y 形状为 (k*n)xm 的结果,这正是您正在寻找的串联大矩阵。

如果分块对角矩阵W的形状太大,内存放不下,可以考虑制作W sparse and compute the product using torch.sparse.mm.

请看看这个link。显然 DGL 已经在做类似的事情了:https://github.com/dmlc/dgl/pull/3641