Pytorch:可微分计数

Pytorch: Differentiable counting

我需要统计一个tensor中满足条件的元素个数,比如统计年龄==60的人数,或者年龄>=50的人数,请问是否有可微分的逼近计数功能?

age >= 50

返回的布尔张量使用 torch.Tensor.sumtorch.sum
age = torch.arange(75)

(age >= 50).sum()
tensor(25)