Missing keys when loading model weights in PyTorch
I want to load weights from a .pth file, for example:
model = my_model()
model.load_state_dict(torch.load("../input/checkpoint/checkpoint.pth"))
However, this raises an error:
RuntimeError: Error(s) in loading state_dict for my_model:
Missing key(s) in state_dict: "att.in_proj_weight", "att.in_proj_bias", "att.out_proj.weight", "att.out_proj.bias".
Unexpected key(s) in state_dict: "in_proj_weight", "in_proj_bias", "out_proj.weight", "out_proj.bias".
It seems that the parameter names in my model differ from the names stored in the state_dict. How can I make them match in this case?
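The error shows that the keys in the checkpoint lack the att. prefix that your model expects. You can create a new dictionary with the prefix added to each key and load that dictionary into your model as follows:
from collections import OrderedDict

import torch

# raw string so that `\t` in the path is not read as a tab character
state_dict = torch.load(r'path\to\checkpoint.pth')

new_state_dict = OrderedDict()
for key, value in state_dict.items():
    # add the `att.` prefix that the model expects
    new_state_dict['att.' + key] = value

# load params
model = my_model()
model.load_state_dict(new_state_dict)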
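To check a renaming like this before loading, it can help to compare the key sets directly. The following is a minimal sketch that uses plain dictionaries in place of real tensors, so it runs without a checkpoint file. The helper names `add_prefix` and `key_mismatch` are hypothetical and not part of PyTorch; the key names are copied from the error message above.

```python
from collections import OrderedDict


def add_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` prepended to every key."""
    return OrderedDict((prefix + key, value) for key, value in state_dict.items())


def key_mismatch(expected_keys, state_dict):
    """Return (missing, unexpected) key lists, similar to the error message."""
    expected = set(expected_keys)
    found = set(state_dict)
    return sorted(expected - found), sorted(found - expected)


# Stand-in values (assumption); in practice these would be tensors
# returned by torch.load(...).
checkpoint = OrderedDict([
    ("in_proj_weight", 0),
    ("in_proj_bias", 1),
    ("out_proj.weight", 2),
    ("out_proj.bias", 3),
])

# Keys the model expects, taken from the "Missing key(s)" line of the error.
expected = [
    "att.in_proj_weight",
    "att.in_proj_bias",
    "att.out_proj.weight",
    "att.out_proj.bias",
]

fixed = add_prefix(checkpoint, "att.")
missing, unexpected = key_mismatch(expected, fixed)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the expected keys would come from model.state_dict().keys(), and once both lists are empty the renamed dictionary can be passed to model.load_state_dict.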