你怎么知道 Pytorch Save 是否只包含模型 and/or 权重?

How do you know if a Pytorch Save contains a model and/or just the weights?

我是 pytorch 的新手,这可能是版本问题,但我看到使用了 torch.load 和 torch.load_state_dict,但在这两种情况下,文件扩展名都很常见“.pth”

我创建的模型,我可以通过 torch.Save 和 torch.Load 保存和加载它们并调用 model.eval()

我有另一个模型文件,我很确定它只是状态字典,因为 model.eval() 在加载后失败。

我如何检查文件并知道其中有一个完整的模型?

非常感谢。

据我所知,没有万无一失的方法来解决这个问题。 torch.save 在后台使用 Python 的 pickle(参考:Pytorch docs),因此用户可以保存任意 Python 对象。例如,以下代码将状态字典包装在字典中:

# example from https://github.com/lucidrains/lightweight-gan/blob/fce20938562a0cc289c915f7317722a8241abd37/lightweight_gan/lightweight_gan.py#L1437
save_data = {
    'GAN': self.GAN.state_dict(),
    'version': __version__,
    'G_scaler': self.G_scaler.state_dict(),
    'D_scaler': self.D_scaler.state_dict()
}
torch.save(save_data, self.model_name(num))

如果有帮助,状态指令本身就是 OrderedDict 对象。如果 isinstance(model, collections.OrderedDict) returns 为真,您可以相当确信 model 是状态命令。 (记得要import collections

模型本身是 torch.nn.Module 的子类,因此您可以通过验证 isinstance(model, torch.nn.Module) returns 是否为真来检查某物是否是模型。