如何使用其索引(2D)获取topk的值?
How to get topk's values with its indices (2D)?
我有两个 3D 张量,我想用一个的前 k 个索引得到另一个前 k 个。
例如下面的张量
a = torch.tensor([[[1], [2], [3]],
[[4], [5], [6]]])
b = torch.tensor([[[7,1], [8,2], [9,3]],
[[10,4],[11,5],[12,6]]])
pytorch 的 topk 函数会给我以下内容。
top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor: tensor([[[3], [2]],
# [[6], [5]]])
# indices: tensor([[[2], [1]],
# [[2], [1]]])
但是我想用a的结果映射到b。
# use indices to do something for b, get torch.tensor([[[8,2], [9,3]],
# [[11,5],[12,6]]])
在这种情况下,我不知道b的真实值,所以我不能用topk来b。
换句话说,我想得到一个函数 foo_slice 如下:
top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor == foo_slice(a, indices)
是否有任何方法可以使用 pytorch 实现此目的?
谢谢!
您正在寻找的解决方案是here
所以针对您的问题的基于代码的解决方案如下
#inputs are changed in order from the above ques
a = torch.tensor([[[1], [2], [3]],
[[5], [6], [4]]])
b = torch.tensor([[[7,1], [8,2], [9,3]],
[[11,5],[12,6],[10,4]]])
top_tensor, indices = torch.topk(a, 2, dim=1)
v = [indices.view(-1,2)[i] for i in range(0,indices.shape[1])]
new_tensor = []
for i,f in enumerate(v):
new_tensor.append(torch.index_select(b[i], 0, f))
print(new_tensor ) #[tensor([[9, 3],
# [8, 2]]),
#tensor([[12, 6],
# [11, 5]])]
我有两个 3D 张量,我想用一个的前 k 个索引得到另一个前 k 个。
例如下面的张量
a = torch.tensor([[[1], [2], [3]],
[[4], [5], [6]]])
b = torch.tensor([[[7,1], [8,2], [9,3]],
[[10,4],[11,5],[12,6]]])
pytorch 的 topk 函数会给我以下内容。
top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor: tensor([[[3], [2]],
# [[6], [5]]])
# indices: tensor([[[2], [1]],
# [[2], [1]]])
但是我想用a的结果映射到b。
# use indices to do something for b, get torch.tensor([[[8,2], [9,3]],
# [[11,5],[12,6]]])
在这种情况下,我不知道b的真实值,所以我不能用topk来b。
换句话说,我想得到一个函数 foo_slice 如下:
top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor == foo_slice(a, indices)
是否有任何方法可以使用 pytorch 实现此目的?
谢谢!
您正在寻找的解决方案是here
所以针对您的问题的基于代码的解决方案如下
#inputs are changed in order from the above ques
a = torch.tensor([[[1], [2], [3]],
[[5], [6], [4]]])
b = torch.tensor([[[7,1], [8,2], [9,3]],
[[11,5],[12,6],[10,4]]])
top_tensor, indices = torch.topk(a, 2, dim=1)
v = [indices.view(-1,2)[i] for i in range(0,indices.shape[1])]
new_tensor = []
for i,f in enumerate(v):
new_tensor.append(torch.index_select(b[i], 0, f))
print(new_tensor ) #[tensor([[9, 3],
# [8, 2]]),
#tensor([[12, 6],
# [11, 5]])]