CrossEntropyLoss equivalence to LogSoftmax + NLLLoss
According to the docs, the CrossEntropyLoss criterion combines the LogSoftmax function and the NLLLoss criterion.
That is all well and good, but testing it does not seem to confirm the claim (i.e., the assertion fails):
import torch
import torch.nn as nn

model_nll = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))

model_ce = nn.Sequential(nn.Linear(3072, 1024),
                         nn.Tanh(),
                         nn.Linear(1024, 512),
                         nn.Tanh(),
                         nn.Linear(512, 128),
                         nn.Tanh(),
                         nn.Linear(128, 2),
                         nn.LogSoftmax(dim=1))

loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_nll = nn.NLLLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(model_nll(t), target)
    loss_ce = loss_fn_ce(model_ce(t), target)
    assert torch.eq(loss_nll, loss_ce)
I am clearly missing something fundamental.
The following assertion, however, passes:
model = nn.Sequential(
    nn.Linear(3072, 1024),
    nn.Tanh(),
    nn.Linear(1024, 512),
    nn.Tanh(),
    nn.Linear(512, 128),
    nn.Tanh(),
    nn.Linear(128, 2),
)

loss_fn_nll = nn.NLLLoss()
loss_fn_ce = nn.CrossEntropyLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(nn.LogSoftmax(dim=1)(model(t)), target)
    loss_ce = loss_fn_ce(model(t), target)
    assert torch.eq(loss_nll, loss_ce)
I assume the weights were generated randomly in the two networks of the original question. Even with torch.manual_seed(0) they are still not equivalent.
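That is expected: setting the seed once only makes the whole run reproducible. The two models are constructed one after the other, so the second one consumes later draws from the same RNG stream and ends up with different weights. A minimal sketch (the small layer sizes here are made up for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
a = nn.Linear(4, 2)
b = nn.Linear(4, 2)                     # built later: different draws from the stream
print(torch.equal(a.weight, b.weight))  # False

torch.manual_seed(0)
c = nn.Linear(4, 2)
torch.manual_seed(0)
d = nn.Linear(4, 2)                     # re-seeding replays the same draws
print(torch.equal(c.weight, d.weight))  # True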
As you can see, the weights are randomly initialized. One way to make two modules share the same weights is simply to export the state of one with state_dict and load it into the other with load_state_dict. It is a one-liner:

>>> model_ce.load_state_dict(model_nll.state_dict())
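For what it's worth, the claimed equivalence can also be checked directly on raw logits with the functional API, which leaves networks and weight initialization out of the picture entirely. A minimal sketch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])

loss_ce = F.cross_entropy(logits, target)
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(loss_ce, loss_nll)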
import torch
import torch.nn as nn

model_nll = nn.Sequential(nn.Linear(3072, 1024),
                          nn.Tanh(),
                          nn.Linear(1024, 512),
                          nn.Tanh(),
                          nn.Linear(512, 128),
                          nn.Tanh(),
                          nn.Linear(128, 2),
                          nn.LogSoftmax(dim=1))

# no LogSoftmax here: CrossEntropyLoss applies it internally
model_ce = nn.Sequential(nn.Linear(3072, 1024),
                         nn.Tanh(),
                         nn.Linear(1024, 512),
                         nn.Tanh(),
                         nn.Linear(512, 128),
                         nn.Tanh(),
                         nn.Linear(128, 2))

# force both models to use the same (randomly initialized) weights
model_nll.load_state_dict(dict(model_ce.named_parameters()))

loss_fn_ce = nn.CrossEntropyLoss()
loss_fn_nll = nn.NLLLoss()

t = torch.rand(1, 3072)
target = torch.tensor([1])

with torch.no_grad():
    loss_nll = loss_fn_nll(model_nll(t), target)
    loss_ce = loss_fn_ce(model_ce(t), target)
    print(loss_nll, loss_ce)
    assert torch.eq(loss_nll, loss_ce)
There are two problems with your code:
- The weights of the two models must be identical. Initialization is random, so you have to force them to be the same.
- You should not add a LogSoftmax to model_ce; it is already computed inside CrossEntropyLoss. Doing it there makes the computation numerically more stable: it simplifies the derivative and allows the log-sum-exp trick to be applied (see the sketch after this list).
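To illustrate that second point: computing log(sum(exp(x))) naively overflows as soon as the logits are large, while subtracting the maximum first gives the same result without ever exponentiating a large number. A minimal sketch of the trick:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])                   # large logits

naive = torch.log(torch.exp(x).sum())                        # exp(1000) overflows -> inf
stable = x.max() + torch.log(torch.exp(x - x.max()).sum())   # shift by the max first
builtin = torch.logsumexp(x, dim=0)                          # PyTorch's stable implementation

print(naive, stable, builtin)                                # inf, ~1002.41, ~1002.41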