如果我想让 OpenCV dnn 模块加载 PyTorch 模型,我应该如何保存它
How should I save the model of PyTorch if I want it loadable by OpenCV dnn module
我用 PyTorch 训练了一个简单的分类模型并用 opencv3.3 加载它,但它抛出异常并说
OpenCV Error: The function/feature is not implemented (Unsupported Lua type) in readObject, file
/home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp,
line 797
/home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp:797:
error: (-213) Unsupported Lua type in function readObject
模型定义
class conv_block(nn.Module):
def __init__(self, in_filter, out_filter, kernel):
super(conv_block, self).__init__()
self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
self.batchnorm = nn.BatchNorm2d(out_filter)
self.maxpool = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.maxpool(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = conv_block(3, 6, 3)
self.conv2 = conv_block(6, 16, 3)
self.fc1 = nn.Linear(16 * 8 * 8, 120)
self.bn1 = nn.BatchNorm1d(120)
self.fc2 = nn.Linear(120, 84)
self.bn2 = nn.BatchNorm1d(84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return x
该模型仅使用Conv2d、ReLU、BatchNorm2d、MaxPool2d和Linear层,opencv3.3支持所有层
我用state_dict
保存
torch.save(net.state_dict(), 'cifar10_model')
用c++加载为
std::string const model_file("/home/some_folder/cifar10_model");
std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);
我想我用错误的方式保存了模型,为了使用 OpenCV 加载,保存 PyTorch 模型的正确方法是什么?谢谢
编辑:
我用另一种方式保存了模型,也加载不出来
torch.save(net, 'cifar10_model.net')
这是一个错误吗?还是我做错了什么?
我找到了答案,opencv3.3 不支持 PyTorch (https://github.com/pytorch/pytorch) but pytorch (https://github.com/hughperkins/pytorch),这是一个很大的惊喜,我从来不知道还有另一个版本的 pytorch 存在(看起来像一个死项目,很久了还没有更新),我希望他们能在 wiki 上提到他们支持的 pytorch。
我用 PyTorch 训练了一个简单的分类模型并用 opencv3.3 加载它,但它抛出异常并说
OpenCV Error: The function/feature is not implemented (Unsupported Lua type) in readObject, file /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp, line 797 /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp:797: error: (-213) Unsupported Lua type in function readObject
模型定义
class conv_block(nn.Module):
def __init__(self, in_filter, out_filter, kernel):
super(conv_block, self).__init__()
self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
self.batchnorm = nn.BatchNorm2d(out_filter)
self.maxpool = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.batchnorm(x)
x = F.relu(x)
x = self.maxpool(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = conv_block(3, 6, 3)
self.conv2 = conv_block(6, 16, 3)
self.fc1 = nn.Linear(16 * 8 * 8, 120)
self.bn1 = nn.BatchNorm1d(120)
self.fc2 = nn.Linear(120, 84)
self.bn2 = nn.BatchNorm1d(84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return x
该模型仅使用Conv2d、ReLU、BatchNorm2d、MaxPool2d和Linear层,opencv3.3支持所有层
我用state_dict
保存torch.save(net.state_dict(), 'cifar10_model')
用c++加载为
std::string const model_file("/home/some_folder/cifar10_model");
std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);
我想我用错误的方式保存了模型,为了使用 OpenCV 加载,保存 PyTorch 模型的正确方法是什么?谢谢
编辑:
我用另一种方式保存了模型,也加载不出来
torch.save(net, 'cifar10_model.net')
这是一个错误吗?还是我做错了什么?
我找到了答案,opencv3.3 不支持 PyTorch (https://github.com/pytorch/pytorch) but pytorch (https://github.com/hughperkins/pytorch),这是一个很大的惊喜,我从来不知道还有另一个版本的 pytorch 存在(看起来像一个死项目,很久了还没有更新),我希望他们能在 wiki 上提到他们支持的 pytorch。