PyTorch model saving error: "Can't pickle local object"

PyTorch model saving error: "Can't pickle local object"

当我尝试用这段代码保存 PyTorch 模型时:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我收到以下错误:

    E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...

      "type " + obj.__name__ + ". It won't be checked "
    Can't pickle local object 'trainModel.<locals>.Net'

当我尝试用这段代码保存 PyTorch 模型时:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我没有收到任何错误,但我想保存 ANN class。我怎么解决这个问题?另外,我可以在之前的其他项目中保存第一个结构的模型

你不能! torch.save 仅保存对象 state_dict()

当您使用以下内容时:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

您正在尝试保存模型本身,但此数据保存在 model.state_dict() 中,并且在使用 state_dict 加载模型时,您应该首先启动一个模型对象。

这正是第二种方法正常工作的原因:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我建议阅读以下 link 中有关如何正确 save\load 模型的 pytorch 文档: https://pytorch.org/tutorials/beginner/saving_loading_models.html

按照通常的正确方法保存和加载模型 https://pytorch.org/tutorials/beginner/saving_loading_models.html 如果您有要保存的 args 或 dicts,也许还有 lambda 函数,有时我会使用 dill 并且错误消失了。例如

def save_for_meta_learning(args, ckpt_filename='ckpt.pt'):
    if is_lead_worker(args.rank):
        import dill
        args.logger.save_current_plots_and_stats()
        # - ckpt
        assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
        args_pickable = uutils.make_args_pickable(args)
        # args.meta_learner.args = args_pickable
        f: nn.Module = get_model_from_ddp(args.base_model)
        # pickle vs torch_uu.save https://discuss.pytorch.org/t/advantages-disadvantages-of-using-pickle-module-to-save-models-vs-torch-save/79016
        torch.save({'training_mode': args.training_mode,  # its or epochs
                    'it': args.it,
                    'epoch_num': args.epoch_num,
                    # 'args': args_pickable,
                    'args_pickable': args_pickable,
                    # 'meta_learner': args.meta_learner,
                    'meta_learner_str': str(args.meta_learner),
                    # 'f': f,
                    'f_state_dict': f.state_dict(),
                    'f_str': str(f),
                    # 'f_modules': f._modules,
                    # 'f_modules_str': str(f._modules),
                    'outer_opt_state_dict': args.outer_opt.state_dict()
                    },
                   pickle_module=dill,
                   f=args.log_root / ckpt_filename)