pytorch中每行的第k个值?
kth-value per row in pytorch?
给定
import torch
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]
意思是我想用不同的k对每一列进行操作。
不是 kthvalue
也不是 topk 提供这样的 API。
是否有 矢量化 解决方法?
备注 - 第k个值不是第k个索引中的值,而是第k小的元素。 Pytorch docs
torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple (values, indices) where values is the k th smallest element of each row of the input tensor in the given dimension dim. And indices is the index location of each element found.
假设您不需要原始矩阵中的索引(如果需要,也只需对第二个 return 值使用花式索引),您可以简单地对值进行排序(默认情况下按最后一个索引)和return 适当的值如下:
def kth_smallest(tensor, indices):
tensor_sorted, _ = torch.sort(tensor)
return tensor_sorted[torch.arange(len(indices)), indices]
此测试用例为您提供所需的值:
tensor = torch.tensor(
[[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)
print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]
给定
import torch
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]
意思是我想用不同的k对每一列进行操作。
不是 kthvalue
也不是 topk 提供这样的 API。
是否有 矢量化 解决方法?
备注 - 第k个值不是第k个索引中的值,而是第k小的元素。 Pytorch docs
torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns a namedtuple (values, indices) where values is the k th smallest element of each row of the input tensor in the given dimension dim. And indices is the index location of each element found.
假设您不需要原始矩阵中的索引(如果需要,也只需对第二个 return 值使用花式索引),您可以简单地对值进行排序(默认情况下按最后一个索引)和return 适当的值如下:
def kth_smallest(tensor, indices):
tensor_sorted, _ = torch.sort(tensor)
return tensor_sorted[torch.arange(len(indices)), indices]
此测试用例为您提供所需的值:
tensor = torch.tensor(
[[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)
print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]