加载使用分布式数据并行保存的模型时出错
Error loading model saved with distributed data parallel
加载从分布式模式下保存的模型时,模型名称不同,导致此错误。我该如何解决?
File "/code/src/bert_structure_prediction/model.py", line 36, in __init__
self.load_state_dict(state_dict)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state
_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:
Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddin
gs.weight", ...etc.
模型名称不匹配的原因是因为DDP包裹了模型对象,导致在分布式数据并行模式下保存模型时层名称不同(具体来说,层名称会在前面加上module.
到型号名称)。要解决此问题,请使用
torch.save(model.module.state_dict(), PATH)
而不是
torch.save(model.state_dict(), PATH)
从数据并行保存时。
加载从分布式模式下保存的模型时,模型名称不同,导致此错误。我该如何解决?
File "/code/src/bert_structure_prediction/model.py", line 36, in __init__
self.load_state_dict(state_dict)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state
_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:
Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddin
gs.weight", ...etc.
模型名称不匹配的原因是因为DDP包裹了模型对象,导致在分布式数据并行模式下保存模型时层名称不同(具体来说,层名称会在前面加上module.
到型号名称)。要解决此问题,请使用
torch.save(model.module.state_dict(), PATH)
而不是
torch.save(model.state_dict(), PATH)
从数据并行保存时。