Pytorch:将张量的前 10% 值设置为零

Pytorch: set top 10% values of the tensor to zero

我有正态分布的 Pytorch 二维张量。 有没有一种快速的方法可以使用 Python 使该张量的前 10% 最大值无效?

我在这里看到两种可能的方法:

  1. 将张量展平为 1d 并对其进行排序
  2. 使用一些本机 Python 运算符(for-if)的非矢量化方式

但这些看起来都不够快。

那么,将张量的 X 最大值设置为零的最快方法是什么?

好吧,Pytorch 似乎有一个有用的运算符 torch.quantile() 在这里有很大帮助。

解(一维张量):

import torch

x = torch.randn(100)
y = torch.tensor(0.) #new value to assign 
split_val = torch.quantile(x, 0.9)
x = torch.where(x < split_val, x, y)