检查张量 A 行中任何前 k 个条目是否与张量 B 行中的 argmax 相等

Check equality of any top k entries in rows of tensor A against argmax in rows of tensor B

新 tensors/pytorch。

我有两个二维张量,A 和 B。

A 包含表示分配给特定索引的概率的浮点数。 B 在正确的索引中包含一个单热二进制向量。

A
tensor([[0.1, 0.4, 0.5],
        [0.5, 0.4, 0.1],
        [0.4, 0.5, 0.1]])

B
tensor([[0, 0, 1],
        [0, 1, 0],
        [0, 0, 1]])

我想找到 A 的任何前 k 个值的索引与 B 中的单热索引匹配的行数。在这种情况下,k=2。

我的尝试:

tops = torch.topk(A, 2, dim=1)

top_idx = tops.indices

top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))
      

如果操作正确,该示例应该 return 一个张量 ([0, 1]),因为前 2 行有前 2 个匹配项,但我得到 (tensor([1]),) 作为 return.

不确定我哪里出错了。感谢您的帮助!

试试这个:

top_idx = torch.topk(A, 2, dim=1).indices

row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)

top_2_matches = torch.arange(len(row_indicator))[row_indicator]

例如:

>>> import torch
>>> A = torch.tensor([[0.1, 0.4, 0.5],
...                   [0.5, 0.4, 0.1],
...                   [0.4, 0.5, 0.1]])
>>> B = torch.tensor([[0, 0, 1],
...                   [0, 1, 0],
...                   [0, 0, 1]])
>>> tops = torch.topk(A, 2, dim=1)
>>>tops
torch.return_types.topk(
values=tensor([[0.5000, 0.4000],
               [0.5000, 0.4000],
               [0.5000, 0.4000]]),
indices=tensor([[2, 1],
                [0, 1],
                [1, 0]]))
>>> top_idx = tops.indices
>>> top_idx
tensor([[2, 1],
        [0, 1],
        [1, 0]])
>>> index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1)
>>> index_indicator
tensor([[ True, False],
        [False,  True],
        [False, False]])
>>> row_indicator = index_indicator.any(dim=1)
>>> row_indicator
tensor([ True,  True, False])
>>> top_2_matches = torch.arange(len(row_indicator))[row_indicator]
>>> top_2_matches
tensor([0, 1])