Pytorch:将张量的前 10% 值设置为零
Pytorch: set top 10% values of the tensor to zero
我有正态分布的 Pytorch 二维张量。
有没有一种快速的方法可以使用 Python 使该张量的前 10% 最大值无效?
我在这里看到两种可能的方法:
- 将张量展平为 1d 并对其进行排序
- 使用一些本机 Python 运算符(for-if)的非矢量化方式
但这些看起来都不够快。
那么,将张量的 X 最大值设置为零的最快方法是什么?
好吧,Pytorch 似乎有一个有用的运算符 torch.quantile() 在这里有很大帮助。
解(一维张量):
import torch
x = torch.randn(100)
y = torch.tensor(0.) #new value to assign
split_val = torch.quantile(x, 0.9)
x = torch.where(x < split_val, x, y)
我有正态分布的 Pytorch 二维张量。 有没有一种快速的方法可以使用 Python 使该张量的前 10% 最大值无效?
我在这里看到两种可能的方法:
- 将张量展平为 1d 并对其进行排序
- 使用一些本机 Python 运算符(for-if)的非矢量化方式
但这些看起来都不够快。
那么,将张量的 X 最大值设置为零的最快方法是什么?
好吧,Pytorch 似乎有一个有用的运算符 torch.quantile() 在这里有很大帮助。
解(一维张量):
import torch
x = torch.randn(100)
y = torch.tensor(0.) #new value to assign
split_val = torch.quantile(x, 0.9)
x = torch.where(x < split_val, x, y)