无法从在线资源下载保存的模型,pickle 错误

Unable to download saved model from online resource, pickle error

我无法下载和使用之前从在线存储库中保存的模型。这是代码:


model = Model().double()   # Model is defined in another class
state_dict = torch.hub.load_state_dict_from_url(r'https://filebin.net/j2977ux7kts41aft/checkpoint_best.pt?t=wjbujfoo')
model.load_state_dict(state_dict)
model.eval()

这给了我以下错误:

Traceback (most recent call last):
  File "/path/file.py", line 47, in <module>
    state_dict = torch.hub.load_state_dict_from_url(r'https://filebin.net/j2977ux7kts41aft/checkpoint_best.pt?t=wjbujfoo')
  File "anaconda3/envs/torch_env/lib/python3.6/site-packages/torch/hub.py", line 466, in load_state_dict_from_url
    return torch.load(cached_file, map_location=map_location)
  File "/anaconda3/envs/torch_env/lib/python3.6/site-packages/torch/serialization.py", line 386, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "anaconda3/envs/torch_env/lib/python3.6/site-packages/torch/serialization.py", line 563, in _load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '\x0a'.

模型位于: https://filebin.net/j2977ux7kts41aft/checkpoint_best.pt?t=wjbujfoo

请注意,我可以完美地手动下载,然后使用torch.load(path)加载它不会出错,但我需要从代码中完成!会不会是从 url 下载时的序列化以某种方式弄乱了 pickle 编码?

编辑:我不必使用 filebin,任何支持我尝试做的事情的在线存储就足够了。

这段带有 'download button' 和 'map_location' 参数的 link 代码对我来说很好用:

state_dict = torch.hub.load_state_dict_from_url(r'https://filebin.net/j2977ux7kts41aft/checkpoint_best.pt?t=wjbujfoo', map_location=torch.device('cpu'))

确实是环境配置的问题。我使用 PyTorch 1.0.2 创建模型,然后更新到 1.2.0 以便使用 torch.hub。这给了我泡菜错误。在 1.2.0 训练新模型后,错误现在消失了。

希望这对以后的人有所帮助:)