Input for torch.nn.functional.gumbel_softmax

Suppose I have a tensor named attn_weights of size [1, a], whose entries represent the attention weights between a given query and |a| keys. I want to select the largest one using torch.nn.functional.gumbel_softmax.
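
For concreteness, my setup looks roughly like this (a = 5 is just a placeholder; in this illustration attn_weights comes from a softmax over raw query-key scores):

```python
import torch
import torch.nn.functional as F

a = 5                                     # number of keys (placeholder)
scores = torch.randn(1, a)                # raw query-key scores
attn_weights = F.softmax(scores, dim=-1)  # size [1, a], entries sum to 1
```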

I found that the docs about this function describe the input as "logits - [..., num_features] unnormalized log probabilities". I wonder whether I should take the log of attn_weights before passing it to gumbel_softmax. I also found that the Wiki defines logit(p) = log(p / (1 - p)), which is different from simply taking the log. Which one should I pass to the function?

Also, I wonder how to choose tau in gumbel_softmax. Are there any guidelines?

I wonder whether I should take the log of attn_weights before passing it into gumbel_softmax?

If attn_weights are probabilities (sum to 1; e.g., the output of a softmax), then yes. Otherwise, no.
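
Here is a minimal sketch, assuming attn_weights really are softmax outputs (names follow your question):

```python
import torch
import torch.nn.functional as F

a = 5
attn_weights = F.softmax(torch.randn(1, a), dim=-1)  # probabilities

# gumbel_softmax expects (unnormalized) log-probabilities,
# so take the log of the probabilities first.
logits = torch.log(attn_weights)

# hard=True returns a one-hot sample (with a straight-through
# gradient), which effectively selects one of the a keys.
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
selected_key = sample.argmax(dim=-1)
```

On the logit question: the Wiki formula logit(p) = log(p / (1 - p)) is the two-class special case. For a categorical distribution over a classes with probability vector p, any vector l with softmax(l) = p is a valid logits vector, and since softmax is invariant to adding a constant, log(p) is one such vector. That is why taking the log of the probabilities is the right move here.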

I wonder how to choose tau in gumbel_softmax. Are there any guidelines?

Usually, it requires tuning. The references provided in the docs can help; see also the annealing sketch after the excerpts below.

From Categorical Reparameterization with Gumbel-Softmax:

  • Figure 1, caption:

    ... (a) For low temperatures (τ = 0.1, τ = 0.5), the expected value of a Gumbel-Softmax random variable approaches the expected value of a categorical random variable with the same logits. As the temperature increases (τ = 1.0, τ = 10.0), the expected value converges to a uniform distribution over the categories.

  • Section 2.2, paragraph 2 (emphasis mine):

    While Gumbel-Softmax samples are differentiable, they are not identical to samples from the corresponding categorical distribution for non-zero temperature. For learning, there is a tradeoff between small temperatures, where samples are close to one-hot but the variance of the gradients is large, and large temperatures, where samples are smooth but the variance of the gradients is small (Figure 1). In practice, we start at a high temperature and anneal to a small but non-zero temperature.

  • Finally, they remind the reader that tau can be learned:

    If τ is a learned parameter (rather than annealed via a fixed schedule), this scheme can be interpreted as entropy regularization (Szegedy et al., 2015; Pereyra et al., 2016), where the Gumbel-Softmax distribution can adaptively adjust the "confidence" of proposed samples during the training process.
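
Not from the paper's code release, just a sketch of the annealing practice described above; the schedule constants are placeholders that need tuning for your task:

```python
import math
import torch
import torch.nn.functional as F

# Placeholder schedule: anneal tau from 5.0 toward a floor of 0.5.
tau_start, tau_min, decay_rate = 5.0, 0.5, 1e-3

def tau_at(step: int) -> float:
    """Temperature at a given training step (exponential decay)."""
    return max(tau_min, tau_start * math.exp(-decay_rate * step))

logits = torch.log(F.softmax(torch.randn(1, 5), dim=-1))
for step in [0, 1000, 5000]:
    tau = tau_at(step)
    # Lower tau -> samples closer to one-hot; higher tau -> closer to uniform.
    sample = F.gumbel_softmax(logits, tau=tau)
    print(f"step={step} tau={tau:.2f} sample={sample}")
```

If you instead want a learned tau as in the last quote, note that F.gumbel_softmax documents tau as a float, so you would typically re-implement the softmax((logits + gumbel_noise) / tau) step yourself, with tau as a torch.nn.Parameter kept positive (e.g. via softplus).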