加载使用分布式数据并行保存的模型时出错

Error loading model saved with distributed data parallel

加载从分布式模式下保存的模型时,模型名称不同,导致此错误。我该如何解决?

  File "/code/src/bert_structure_prediction/model.py", line 36, in __init__                         
    self.load_state_dict(state_dict)                                                                
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state
_dict                                                                                               
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                       
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:                           
        Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddin
gs.weight", ...etc.

模型名称不匹配的原因是因为DDP包裹了模型对象,导致在分布式数据并行模式下保存模型时层名称不同(具体来说,层名称会在前面加上module.到型号名称)。要解决此问题,请使用

torch.save(model.module.state_dict(), PATH)

而不是

torch.save(model.state_dict(), PATH)

从数据并行保存时。