Numpy/PyTorch 有趣的张量积

Numpy/PyTorch funny tensor product

我有一个 4 维火炬张量参数定义如下:

nn.parameter.Parameter(data=torch.Tensor((13,13,13,13)), requires_grad=True)

和四个带有 dims (batch_size,13) 的张量(或一个带有 dims (batch_size,4,13) 的张量)。 我想得到一个张量,其 dims (batch_size) 等于这张图片末尾的公式: [编辑:我在第一张图片中犯了一个错误,我已经更正了] 我在 torch 文档中看到了函数 tensordot,但我无法自己实现它。

每当你有一个有趣的张量积torch.einsum (or numpy.einsum)是你的朋友:

batch_size = 5
A = torch.rand(13, 13, 13, 13)
a = torch.rand(batch_size, 13)
b = torch.rand(batch_size, 13)
c = torch.rand(batch_size, 13)
d = torch.rand(batch_size, 13)
B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)