torch矩阵等价和运算

torch matrix equaity sum operation

我想做一个类似于矩阵乘法的运算,除了我想检查相等性而不是乘法。我想要实现的效果类似下面这样:

a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
result = [[sum(a[i] == b [j]) for j in range(len(b))] for i in range(len(a))]

有没有一种方法可以使用 einsum 或 pytorch 中的任何其他函数来有效地实现上述目标?

您可以使用 torch.repeat and torch.repeat_interleave:

a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

mask = a.repeat_interleave(3, dim=0) == b.repeat((2, 1))

torch.sum(mask, axis=1).reshape(a.shape)

# output
tensor([[3, 0, 0],
        [0, 3, 0]])

您可以使用 broadcasting 来做同样的事情,例如

result = (a[:, None, :] == b[None, :, :]).sum(dim=2)

这里 None 只是介绍了一个虚拟尺寸 - 或者,您可以使用视觉效果较差的 .unsqueeze() 代替。

矩阵乘法在 einsum 表示法中是 ij,jk->ik,所有这些运算都等同于不同级别的冗长程度:

a @ b
torch.einsum("ij,jk", a, b)
torch.einsum("ij,jk->ik", a, b)
(a[:,:,None] * b[None,:,:]).sum(1)

"乘以 ik 维度并减少 j 维度"

    i, j, k             i,    j, k
a: (2, 3)        =>    (2,    3, None)
b:    (3, 3)           (None, 3, 3)

现在从这个函数分解中应该清楚乘法可以用任何二元运算代替,例如相等运算。

不幸的是,pytorch 中没有通用形式的 einsum (AFAIK) 交换乘法“out-of-the-box”。然而,einops 库基本上是深度学习框架(如 PyTorch)的包装器。