PyTorch 中两个概率分布的 KL 散度
KL Divergence for two probability distributions in PyTorch
我有两个概率分布。我应该如何在 PyTorch 中找到它们之间的 KL 散度?常规交叉熵只接受整数标签。
是的,PyTorch 在torch.nn.functional
下有一个名为kl_div
的方法可以直接计算张量之间的KL-散度。假设您有相同形状的张量 a
和 b
。您可以使用以下代码:
import torch.nn.functional as F
out = F.kl_div(a, b)
更多详细信息,请参阅上述方法文档。
如果你有两个 pytorch distribution
object. Then you are better off using the function torch.distributions.kl.kl_divergence(p, q)
. For documentation follow the link
形式的概率分布
函数kl_div
与wiki的解释不一样
我使用以下:
# this is the same example in wiki
P = torch.Tensor([0.36, 0.48, 0.16])
Q = torch.Tensor([0.333, 0.333, 0.333])
(P * (P / Q).log()).sum()
# tensor(0.0863), 10.2 µs ± 508
F.kl_div(Q.log(), P, None, None, 'sum')
# tensor(0.0863), 14.1 µs ± 408 ns
与kl_div
相比,更快
如果使用 Torch 发行版
mu = torch.Tensor([0] * 100)
sd = torch.Tensor([1] * 100)
p = torch.distributions.Normal(mu,sd)
q = torch.distributions.Normal(mu,sd)
out = torch.distributions.kl_divergence(p, q).mean()
out.tolist() == 0
True
如果您使用的是正态分布,那么下面的代码将直接比较两个分布本身:
p = torch.distributions.normal.Normal(p_mu, p_std)
q = torch.distributions.normal.Normal(q_mu, q_std)
loss = torch.distributions.kl_divergence(p, q)
p 和 q 是两个张量对象。
这段代码可以工作,不会给出任何 NotImplementedError。
我有两个概率分布。我应该如何在 PyTorch 中找到它们之间的 KL 散度?常规交叉熵只接受整数标签。
是的,PyTorch 在torch.nn.functional
下有一个名为kl_div
的方法可以直接计算张量之间的KL-散度。假设您有相同形状的张量 a
和 b
。您可以使用以下代码:
import torch.nn.functional as F
out = F.kl_div(a, b)
更多详细信息,请参阅上述方法文档。
如果你有两个 pytorch distribution
object. Then you are better off using the function torch.distributions.kl.kl_divergence(p, q)
. For documentation follow the link
函数kl_div
与wiki的解释不一样
我使用以下:
# this is the same example in wiki
P = torch.Tensor([0.36, 0.48, 0.16])
Q = torch.Tensor([0.333, 0.333, 0.333])
(P * (P / Q).log()).sum()
# tensor(0.0863), 10.2 µs ± 508
F.kl_div(Q.log(), P, None, None, 'sum')
# tensor(0.0863), 14.1 µs ± 408 ns
与kl_div
相比,更快
如果使用 Torch 发行版
mu = torch.Tensor([0] * 100)
sd = torch.Tensor([1] * 100)
p = torch.distributions.Normal(mu,sd)
q = torch.distributions.Normal(mu,sd)
out = torch.distributions.kl_divergence(p, q).mean()
out.tolist() == 0
True
如果您使用的是正态分布,那么下面的代码将直接比较两个分布本身:
p = torch.distributions.normal.Normal(p_mu, p_std)
q = torch.distributions.normal.Normal(q_mu, q_std)
loss = torch.distributions.kl_divergence(p, q)
p 和 q 是两个张量对象。
这段代码可以工作,不会给出任何 NotImplementedError。