torch矩阵等价和运算
torch matrix equaity sum operation
我想做一个类似于矩阵乘法的运算,除了我想检查相等性而不是乘法。我想要实现的效果类似下面这样:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
result = [[sum(a[i] == b [j]) for j in range(len(b))] for i in range(len(a))]
有没有一种方法可以使用 einsum 或 pytorch 中的任何其他函数来有效地实现上述目标?
您可以使用 torch.repeat and torch.repeat_interleave:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = a.repeat_interleave(3, dim=0) == b.repeat((2, 1))
torch.sum(mask, axis=1).reshape(a.shape)
# output
tensor([[3, 0, 0],
[0, 3, 0]])
您可以使用 broadcasting 来做同样的事情,例如
result = (a[:, None, :] == b[None, :, :]).sum(dim=2)
这里 None
只是介绍了一个虚拟尺寸 - 或者,您可以使用视觉效果较差的 .unsqueeze()
代替。
矩阵乘法在 einsum 表示法中是 ij,jk->ik
,所有这些运算都等同于不同级别的冗长程度:
a @ b
torch.einsum("ij,jk", a, b)
torch.einsum("ij,jk->ik", a, b)
(a[:,:,None] * b[None,:,:]).sum(1)
"乘以 i
和 k
维度并减少 j
维度"
i, j, k i, j, k
a: (2, 3) => (2, 3, None)
b: (3, 3) (None, 3, 3)
现在从这个函数分解中应该清楚乘法可以用任何二元运算代替,例如相等运算。
不幸的是,pytorch 中没有通用形式的 einsum (AFAIK) 交换乘法“out-of-the-box”。然而,einops
库基本上是深度学习框架(如 PyTorch)的包装器。
我想做一个类似于矩阵乘法的运算,除了我想检查相等性而不是乘法。我想要实现的效果类似下面这样:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
result = [[sum(a[i] == b [j]) for j in range(len(b))] for i in range(len(a))]
有没有一种方法可以使用 einsum 或 pytorch 中的任何其他函数来有效地实现上述目标?
您可以使用 torch.repeat and torch.repeat_interleave:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = a.repeat_interleave(3, dim=0) == b.repeat((2, 1))
torch.sum(mask, axis=1).reshape(a.shape)
# output
tensor([[3, 0, 0],
[0, 3, 0]])
您可以使用 broadcasting 来做同样的事情,例如
result = (a[:, None, :] == b[None, :, :]).sum(dim=2)
这里 None
只是介绍了一个虚拟尺寸 - 或者,您可以使用视觉效果较差的 .unsqueeze()
代替。
矩阵乘法在 einsum 表示法中是 ij,jk->ik
,所有这些运算都等同于不同级别的冗长程度:
a @ b
torch.einsum("ij,jk", a, b)
torch.einsum("ij,jk->ik", a, b)
(a[:,:,None] * b[None,:,:]).sum(1)
"乘以 i
和 k
维度并减少 j
维度"
i, j, k i, j, k
a: (2, 3) => (2, 3, None)
b: (3, 3) (None, 3, 3)
现在从这个函数分解中应该清楚乘法可以用任何二元运算代替,例如相等运算。
不幸的是,pytorch 中没有通用形式的 einsum (AFAIK) 交换乘法“out-of-the-box”。然而,einops
库基本上是深度学习框架(如 PyTorch)的包装器。