在 PyTorch 中计算两个相同大小的方阵的逐行点积的有效方法

Efficient method to compute the row-wise dot product of two square matrices of the same size in PyTorch

假设我有两个大小相同的方阵A、B

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[1, 1], [1, 1]])

我想要一个由行点积组成的结果张量,比方说

tensor([3, 7])  # i.e. (1*1 + 2*1, 3*1 + 4*1) 

在 PyTorch 中实现这一目标的有效方法是什么?

import torch
import numpy as np

def row_wise_product(A, B):
    num_rows, num_cols = A.shape[0], A.shape[1]
    prod = torch.bmm(A.view(num_rows, 1, num_cols), B.view(num_rows, num_cols, 1))
    return prod

A = torch.tensor(np.array([[1, 2], [3, 4]]))
B = torch.tensor(np.array([[1, 1], [1, 1]]))
C = row_wise_product(A, B)

正如您所说,您可以使用 torch.bmm 但您首先需要广播您的输入:

>>> torch.bmm(A[..., None, :], B[..., None])
tensor([[[3]],

        [[7]]])

或者您可以使用 torch.einsum:

>>> torch.einsum('ij,ij->i', A, B)
tensor([3, 7])