CrossEntropyLoss equivalence to LogSoftmax + NLLLoss

According to the docs, the CrossEntropyLoss criterion combines the LogSoftmax function and the NLLLoss criterion.
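In other words, for raw logits x and a tensor of class indices t, the claim is that (a restatement of the docs, not code taken from them):

nn.CrossEntropyLoss()(x, t) == nn.NLLLoss()(nn.LogSoftmax(dim=1)(x), t)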

That's all well and good, but testing it does not seem to confirm this claim (i.e., the assertion below fails):

import torch
import torch.nn as nn

model_nll = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))


model_ce = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))

loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_nll = nn.NLLLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(model_nll(t), target)
    loss_ce = loss_fn_ce(model_ce(t), target)
    assert torch.eq(loss_nll, loss_ce)

I am clearly missing something fundamental.

The following assertion, on the other hand, passes:

model = nn.Sequential(
    nn.Linear(3072, 1024),
    nn.Tanh(),
    nn.Linear(1024, 512),
    nn.Tanh(),
    nn.Linear(512, 128),
    nn.Tanh(),
    nn.Linear(128, 2),
)


loss_fn_nll = nn.NLLLoss()
loss_fn_ce = nn.CrossEntropyLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(nn.LogSoftmax(dim=1)(model(t)), target)
    loss_ce = loss_fn_ce(model(t), target)

    assert torch.eq(loss_nll, loss_ce)

I assume the weights are randomly generated for the two networks in the original question. Even with torch.manual_seed(0) they are still not equal.

As you can see, the weights are randomly initialized.
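Note that a single torch.manual_seed(0) call at the top is not enough to make them equal: constructing the first model consumes random numbers from the global generator, so the second model is initialized from a different RNG state. Re-seeding before each construction would work; a minimal sketch (the layer sizes here are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)
a = nn.Linear(4, 2)                      # consumes RNG state
b = nn.Linear(4, 2)                      # built from a different RNG state
print(torch.equal(a.weight, b.weight))   # False

torch.manual_seed(0)
c = nn.Linear(4, 2)
torch.manual_seed(0)                     # re-seed before each construction
d = nn.Linear(4, 2)
print(torch.equal(c.weight, d.weight))   # True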

One way to make the two modules share the same weights is simply to export the state of one with state_dict and set it on the other with load_state_dict.

It's a one-liner:

>>> model_ce.load_state_dict(model_nll.state_dict())

Here is the full working example:
import torch
import torch.nn as nn
model_nll = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))


model_ce = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2)
)

# copy the weights so that both models are identical
model_nll.load_state_dict(model_ce.state_dict())


loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_nll = nn.NLLLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(model_nll(t), target)
    loss_ce = loss_fn_ce(model_ce(t), target)

    print(loss_nll, loss_ce)
    assert torch.eq(loss_nll, loss_ce)

There are two issues with your code:

  • The weights of the two models must be identical. Initialization is random by default, so you have to force them to be the same, e.g. with load_state_dict as shown above.
  • You should not append LogSoftmax to model_ce: it is already computed inside CrossEntropyLoss. Doing it there also makes it more numerically stable: it simplifies the derivative and allows the log-sum-exp trick to be applied (see the sketch below).
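For reference, a minimal sketch of both points in isolation, using PyTorch's functional API (the logits here are arbitrary random values):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 2)   # raw scores, no LogSoftmax applied
target = torch.tensor([1])

# cross_entropy applies log_softmax + nll_loss internally
ce = F.cross_entropy(logits, target)
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(ce, nll)

# log_softmax is evaluated in the numerically stable log-sum-exp form:
# log_softmax(x) = x - logsumexp(x)
stable = logits - torch.logsumexp(logits, dim=1, keepdim=True)
assert torch.allclose(stable, F.log_softmax(logits, dim=1))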