火炬:log_softmax 基地 2?

pytorch: log_softmax base 2?

我想使用 log base 2 从 PyTorch 的 logit 输出中获取意外值。

给定一个 logits 张量,一种方法是:

probs = nn.functional.softmax(logits, dim = 2)
surprisals = -torch.log2(probs)

不过PyTorch提供了结合log和softmax的函数,比上面的更快:

surprisals = -nn.functional.log_softmax(logits, dim = 2)

但这似乎是 return 以 e 为底的值,这是我不想要的。是否有类似 log_softmax 的函数,但它使用基数 2?我已经尝试了 log2_softmaxlog_softmax2,这两个似乎都不起作用,并且没有找到在线文档。

利用以下数学恒等式可以很容易地改变对数底数这一事实怎么样

就是 F.log_softmax() 给你的。您只需

surprisals = - (1 / torch.log(2.)) * nn.functional.log_softmax(logits, dim = 2)

它只是一个标量乘法。所以,它几乎没有任何性能损失。