Pytorch: AttributeError: 'function' object has no attribute 'copy'
Pytorch: AttributeError: 'function' object has no attribute 'copy'
我正在尝试加载模型 state_dict
我在 Google Colab GPU 上训练,这是我加载模型的代码:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth",device)))
model = model.to(device)
model.eval()
这是错误:
state_dict = state_dict.copy()
AttributeError: 'function' object has no attribute 'copy'
火炬:
>>> import torch
>>> print (torch.__version__)
1.4.0
>>> import torchvision
>>> print (torchvision.__version__)
0.5.0
求助,找遍了都无果
[完整错误详情][1] https://i.stack.imgur.com/s22DL.png
我猜这是你做错了。
您保存了函数
torch.save(model.state_dict, 'model_state.pth')
而不是 state_dict()
torch.save(model.state_dict(), 'model_state.pth')
否则,一切都应该按预期工作。 (我在 Colab 上测试了以下代码)
将 model.state_dict()
替换为 model.state_dict
以重现错误
import copy
model = TheModelClass()
torch.save(model.state_dict(), 'model_state.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(copy.deepcopy(torch.load("model_state.pth",device)))
我正在尝试加载模型 state_dict
我在 Google Colab GPU 上训练,这是我加载模型的代码:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth",device)))
model = model.to(device)
model.eval()
这是错误:
state_dict = state_dict.copy()
AttributeError: 'function' object has no attribute 'copy'
火炬:
>>> import torch
>>> print (torch.__version__)
1.4.0
>>> import torchvision
>>> print (torchvision.__version__)
0.5.0
求助,找遍了都无果
[完整错误详情][1] https://i.stack.imgur.com/s22DL.png
我猜这是你做错了。 您保存了函数
torch.save(model.state_dict, 'model_state.pth')
而不是 state_dict()
torch.save(model.state_dict(), 'model_state.pth')
否则,一切都应该按预期工作。 (我在 Colab 上测试了以下代码)
将 model.state_dict()
替换为 model.state_dict
以重现错误
import copy
model = TheModelClass()
torch.save(model.state_dict(), 'model_state.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(copy.deepcopy(torch.load("model_state.pth",device)))