理解 PyTorch 的 Einsum 函数的一个例子
Understanding an example of PyTorch's Einsum function
我正在研究一些代码,我遇到了一个我不理解的 PyTorch 的 einsum 函数的用法。文档是 here.
片段看起来像(从原来的稍微修改):
import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])
所以符号是 n=64,c=64,u=25,v=25,t=64。
我不太确定发生了什么。我认为对于 t 中的每个 25 维向量(其中 64 个),每个向量都与 u=25 个元素大小为 25 的向量相乘,然后将结果相加,或者更确切地说,是 25 维向量的 25 个点积?
任何见解表示赞赏。
基本上,您可以将其视为对特定维度进行点积,然后重新组织其余维度。
为简单起见,让我们忽略批处理维度n
和c
(因为它们在ncuv,nctv->nctu
前后是一致的),并讨论:
import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])
请注意,v
在 einsum 之后消失,这意味着 v
是被求和的维度,而 t
和 u
不是。可以这样理解:x
是25
个25维向量的集合; y
是 64
个 25 维向量的集合。计算 y
中的第 t
个向量和 x
中的第 u
个向量的点积,并将其放入第 t
行和u
-第 result
列。
你也可以改写成一个数学方程式:
result[n,c,t,u] = \sum_{v} x[n,c,u,v] * y[n,c,t,v], for each n, c, t, u
注意两点:
- 求和超过求和模式中消失的索引
nctu,ncuv->nctv
- 出现在模式右侧的索引是结果张量的索引
我正在研究一些代码,我遇到了一个我不理解的 PyTorch 的 einsum 函数的用法。文档是 here.
片段看起来像(从原来的稍微修改):
import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])
所以符号是 n=64,c=64,u=25,v=25,t=64。
我不太确定发生了什么。我认为对于 t 中的每个 25 维向量(其中 64 个),每个向量都与 u=25 个元素大小为 25 的向量相乘,然后将结果相加,或者更确切地说,是 25 维向量的 25 个点积?
任何见解表示赞赏。
基本上,您可以将其视为对特定维度进行点积,然后重新组织其余维度。
为简单起见,让我们忽略批处理维度n
和c
(因为它们在ncuv,nctv->nctu
前后是一致的),并讨论:
import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])
请注意,v
在 einsum 之后消失,这意味着 v
是被求和的维度,而 t
和 u
不是。可以这样理解:x
是25
个25维向量的集合; y
是 64
个 25 维向量的集合。计算 y
中的第 t
个向量和 x
中的第 u
个向量的点积,并将其放入第 t
行和u
-第 result
列。
你也可以改写成一个数学方程式:
result[n,c,t,u] = \sum_{v} x[n,c,u,v] * y[n,c,t,v], for each n, c, t, u
注意两点:
- 求和超过求和模式中消失的索引
nctu,ncuv->nctv
- 出现在模式右侧的索引是结果张量的索引