如何在 PyTorch 中进行批量点积?

How to do batched dot product in PyTorch?

我有一个大小为 [B, N, 3] 的输入张量和一个大小为 [N, 3] 的测试张量。我想应用两个张量的点积,基本上得到 [B, N] 。这真的可能吗?

是的,有可能:

a = torch.randn(5, 4, 3)
b = torch.randn(4, 3)

c = torch.einsum('ijk,jk->ij', a, b) # torch.Size([5, 4])

另一种选择:

a = torch.randn(5, 4, 3)
b = torch.randn(4, 3)

c = (a * b[None, ...]).sum(dim=-1) # torch.Size([5, 4])