比较张量中相等元素的数量
Compare Number of Equal Elements in Tensors
我有两个 1000 * 1 维的张量。我想检查这两个张量中 1000 个元素中有多少个元素相等。我想我应该能够像 Numpy 一样在一行中做到这一点,但找不到类似的功能。
类似
equal_count = len((tensor_1.flatten() == tensor_2.flatten()).nonzero().flatten())
应该可以。
您可以只使用 ==
运算符来检查是否相等,然后对结果张量求和:
# Import torch and create dummy tensors
>>> import torch
>>> A = torch.randint(2, (10,))
>>> A
tensor([0, 0, 0, 1, 0, 1, 0, 0, 1, 1])
>>> B = torch.randint(2, (10,))
>>> B
tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
# Checking for number of equal values
>>> (A == B).sum()
tensor(3)
编辑:
torch.eq
产生相同的结果。因此,如果您出于某种原因更喜欢:
>>> torch.eq(A, B).sum()
tensor(3)
我有两个 1000 * 1 维的张量。我想检查这两个张量中 1000 个元素中有多少个元素相等。我想我应该能够像 Numpy 一样在一行中做到这一点,但找不到类似的功能。
类似
equal_count = len((tensor_1.flatten() == tensor_2.flatten()).nonzero().flatten())
应该可以。
您可以只使用 ==
运算符来检查是否相等,然后对结果张量求和:
# Import torch and create dummy tensors
>>> import torch
>>> A = torch.randint(2, (10,))
>>> A
tensor([0, 0, 0, 1, 0, 1, 0, 0, 1, 1])
>>> B = torch.randint(2, (10,))
>>> B
tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
# Checking for number of equal values
>>> (A == B).sum()
tensor(3)
编辑:
torch.eq
产生相同的结果。因此,如果您出于某种原因更喜欢:
>>> torch.eq(A, B).sum()
tensor(3)