比较张量中相等元素的数量

Compare Number of Equal Elements in Tensors

我有两个 1000 * 1 维的张量。我想检查这两个张量中 1000 个元素中有多少个元素相等。我想我应该能够像 Numpy 一样在一行中做到这一点,但找不到类似的功能。

类似

equal_count = len((tensor_1.flatten() == tensor_2.flatten()).nonzero().flatten())

应该可以。

您可以只使用 == 运算符来检查是否相等,然后对结果张量求和:

# Import torch and create dummy tensors
>>> import torch
>>> A = torch.randint(2, (10,))
>>> A
tensor([0, 0, 0, 1, 0, 1, 0, 0, 1, 1])
>>> B = torch.randint(2, (10,))
>>> B
tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])

# Checking for number of equal values
>>> (A == B).sum()
tensor(3)

编辑:

torch.eq 产生相同的结果。因此,如果您出于某种原因更喜欢:

>>> torch.eq(A, B).sum()
tensor(3)