检查张量 A 行中任何前 k 个条目是否与张量 B 行中的 argmax 相等
Check equality of any top k entries in rows of tensor A against argmax in rows of tensor B
新 tensors/pytorch。
我有两个二维张量,A 和 B。
A 包含表示分配给特定索引的概率的浮点数。 B 在正确的索引中包含一个单热二进制向量。
A
tensor([[0.1, 0.4, 0.5],
[0.5, 0.4, 0.1],
[0.4, 0.5, 0.1]])
B
tensor([[0, 0, 1],
[0, 1, 0],
[0, 0, 1]])
我想找到 A 的任何前 k 个值的索引与 B 中的单热索引匹配的行数。在这种情况下,k=2。
我的尝试:
tops = torch.topk(A, 2, dim=1)
top_idx = tops.indices
top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))
如果操作正确,该示例应该 return 一个张量 ([0, 1]),因为前 2 行有前 2 个匹配项,但我得到 (tensor([1]),)
作为 return.
不确定我哪里出错了。感谢您的帮助!
试试这个:
top_idx = torch.topk(A, 2, dim=1).indices
row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)
top_2_matches = torch.arange(len(row_indicator))[row_indicator]
例如:
>>> import torch
>>> A = torch.tensor([[0.1, 0.4, 0.5],
... [0.5, 0.4, 0.1],
... [0.4, 0.5, 0.1]])
>>> B = torch.tensor([[0, 0, 1],
... [0, 1, 0],
... [0, 0, 1]])
>>> tops = torch.topk(A, 2, dim=1)
>>>tops
torch.return_types.topk(
values=tensor([[0.5000, 0.4000],
[0.5000, 0.4000],
[0.5000, 0.4000]]),
indices=tensor([[2, 1],
[0, 1],
[1, 0]]))
>>> top_idx = tops.indices
>>> top_idx
tensor([[2, 1],
[0, 1],
[1, 0]])
>>> index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1)
>>> index_indicator
tensor([[ True, False],
[False, True],
[False, False]])
>>> row_indicator = index_indicator.any(dim=1)
>>> row_indicator
tensor([ True, True, False])
>>> top_2_matches = torch.arange(len(row_indicator))[row_indicator]
>>> top_2_matches
tensor([0, 1])
新 tensors/pytorch。
我有两个二维张量,A 和 B。
A 包含表示分配给特定索引的概率的浮点数。 B 在正确的索引中包含一个单热二进制向量。
A
tensor([[0.1, 0.4, 0.5],
[0.5, 0.4, 0.1],
[0.4, 0.5, 0.1]])
B
tensor([[0, 0, 1],
[0, 1, 0],
[0, 0, 1]])
我想找到 A 的任何前 k 个值的索引与 B 中的单热索引匹配的行数。在这种情况下,k=2。
我的尝试:
tops = torch.topk(A, 2, dim=1)
top_idx = tops.indices
top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))
如果操作正确,该示例应该 return 一个张量 ([0, 1]),因为前 2 行有前 2 个匹配项,但我得到 (tensor([1]),)
作为 return.
不确定我哪里出错了。感谢您的帮助!
试试这个:
top_idx = torch.topk(A, 2, dim=1).indices
row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)
top_2_matches = torch.arange(len(row_indicator))[row_indicator]
例如:
>>> import torch
>>> A = torch.tensor([[0.1, 0.4, 0.5],
... [0.5, 0.4, 0.1],
... [0.4, 0.5, 0.1]])
>>> B = torch.tensor([[0, 0, 1],
... [0, 1, 0],
... [0, 0, 1]])
>>> tops = torch.topk(A, 2, dim=1)
>>>tops
torch.return_types.topk(
values=tensor([[0.5000, 0.4000],
[0.5000, 0.4000],
[0.5000, 0.4000]]),
indices=tensor([[2, 1],
[0, 1],
[1, 0]]))
>>> top_idx = tops.indices
>>> top_idx
tensor([[2, 1],
[0, 1],
[1, 0]])
>>> index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1)
>>> index_indicator
tensor([[ True, False],
[False, True],
[False, False]])
>>> row_indicator = index_indicator.any(dim=1)
>>> row_indicator
tensor([ True, True, False])
>>> top_2_matches = torch.arange(len(row_indicator))[row_indicator]
>>> top_2_matches
tensor([0, 1])