pytorch中每行的第k个值?

kth-value per row in pytorch?

给定

import torch    
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]

意思是我想用不同的k对每一列进行操作。
不是 kthvalue 也不是 topk 提供这样的 API。 是否有 矢量化 解决方法?
备注 - 第k个值不是第k个索引中的值,而是第k小的元素。 Pytorch docs

torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple (values, indices) where values is the k th smallest element of each row of the input tensor in the given dimension dim. And indices is the index location of each element found.

假设您不需要原始矩阵中的索引(如果需要,也只需对第二个 return 值使用花式索引),您可以简单地对值进行排序(默认情况下按最后一个索引)和return 适当的值如下:

def kth_smallest(tensor, indices):
    tensor_sorted, _ = torch.sort(tensor)
    return tensor_sorted[torch.arange(len(indices)), indices]

此测试用例为您提供所需的值:

tensor = torch.tensor(
    [[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)

print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]