理解 PyTorch 的 Einsum 函数的一个例子

Understanding an example of PyTorch's Einsum function

我正在研究一些代码,我遇到了一个我不理解的 PyTorch 的 einsum 函数的用法。文档是 here.

片段看起来像(从原来的稍微修改):

import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])

所以符号是 n=64,c=64,u=25,v=25,t=64。

我不太确定发生了什么。我认为对于 t 中的每个 25 维向量(其中 64 个),每个向量都与 u=25 个元素大小为 25 的向量相乘,然后将结果相加,或者更确切地说,是 25 维向量的 25 个点积?

任何见解表示赞赏。

基本上,您可以将其视为对特定维度进行点积,然后重新组织其余维度。

为简单起见,让我们忽略批处理维度nc(因为它们在ncuv,nctv->nctu前后是一致的),并讨论:

import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])

请注意,v 在 einsum 之后消失,这意味着 v 是被求和的维度,而 tu 不是。可以这样理解:x25个25维向量的集合; y64 个 25 维向量的集合。计算 y 中的第 t 个向量和 x 中的第 u 个向量的点积,并将其放入第 t 行和u-第 result 列。

你也可以改写成一个数学方程式:

result[n,c,t,u] = \sum_{v} x[n,c,u,v] * y[n,c,t,v], for each n, c, t, u

注意两点:

  • 求和超过求和模式中消失的索引nctu,ncuv->nctv
  • 出现在模式右侧的索引是结果张量的索引