你怎么知道 Pytorch Save 是否只包含模型 and/or 权重?
How do you know if a Pytorch Save contains a model and/or just the weights?
我是 pytorch 的新手,这可能是版本问题,但我看到使用了 torch.load 和 torch.load_state_dict,但在这两种情况下,文件扩展名都很常见“.pth”
我创建的模型,我可以通过 torch.Save 和 torch.Load 保存和加载它们并调用 model.eval()
我有另一个模型文件,我很确定它只是状态字典,因为 model.eval() 在加载后失败。
我如何检查文件并知道其中有一个完整的模型?
非常感谢。
据我所知,没有万无一失的方法来解决这个问题。 torch.save
在后台使用 Python 的 pickle(参考:Pytorch docs),因此用户可以保存任意 Python 对象。例如,以下代码将状态字典包装在字典中:
# example from https://github.com/lucidrains/lightweight-gan/blob/fce20938562a0cc289c915f7317722a8241abd37/lightweight_gan/lightweight_gan.py#L1437
save_data = {
'GAN': self.GAN.state_dict(),
'version': __version__,
'G_scaler': self.G_scaler.state_dict(),
'D_scaler': self.D_scaler.state_dict()
}
torch.save(save_data, self.model_name(num))
如果有帮助,状态指令本身就是 OrderedDict
对象。如果 isinstance(model, collections.OrderedDict)
returns 为真,您可以相当确信 model
是状态命令。 (记得要import collections
)
模型本身是 torch.nn.Module
的子类,因此您可以通过验证 isinstance(model, torch.nn.Module)
returns 是否为真来检查某物是否是模型。
我是 pytorch 的新手,这可能是版本问题,但我看到使用了 torch.load 和 torch.load_state_dict,但在这两种情况下,文件扩展名都很常见“.pth”
我创建的模型,我可以通过 torch.Save 和 torch.Load 保存和加载它们并调用 model.eval()
我有另一个模型文件,我很确定它只是状态字典,因为 model.eval() 在加载后失败。
我如何检查文件并知道其中有一个完整的模型?
非常感谢。
据我所知,没有万无一失的方法来解决这个问题。 torch.save
在后台使用 Python 的 pickle(参考:Pytorch docs),因此用户可以保存任意 Python 对象。例如,以下代码将状态字典包装在字典中:
# example from https://github.com/lucidrains/lightweight-gan/blob/fce20938562a0cc289c915f7317722a8241abd37/lightweight_gan/lightweight_gan.py#L1437
save_data = {
'GAN': self.GAN.state_dict(),
'version': __version__,
'G_scaler': self.G_scaler.state_dict(),
'D_scaler': self.D_scaler.state_dict()
}
torch.save(save_data, self.model_name(num))
如果有帮助,状态指令本身就是 OrderedDict
对象。如果 isinstance(model, collections.OrderedDict)
returns 为真,您可以相当确信 model
是状态命令。 (记得要import collections
)
模型本身是 torch.nn.Module
的子类,因此您可以通过验证 isinstance(model, torch.nn.Module)
returns 是否为真来检查某物是否是模型。