PyTorch 广播 4D 和 2D 矩阵的乘法?

PyTorch broadcast multiplication of 4D and 2D matrix?

如何通过广播将这两个矩阵相乘?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

输出应该是:

(10, 120, 180, 64) == (N, H, W, Y)

我假设 x 是某种批量示例,w 矩阵是相应的权重。在这种情况下,你可以简单地做:

out = x @ w.T

这是一个张量乘法,而不是逐元素乘法。您不能进行逐元素乘法来获得这样的形状,并且此操作没有意义。您所能做的就是以某种方式 unsqueeze 两个矩阵都可以广播它们的形状,并对您出于某种原因不想要的维度应用一些操作,如下所示:

x : torch.Size([10, 120, 180, 30, 1])
W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well

在这样 unsqueezing 之后,您可以沿着第三个 dimx*wsummean 以获得所需的形状。

为清楚起见,两种方式并不等同。