如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它

Posted

技术标签:

【中文标题】如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它【英文标题】:How should I save the model of PyTorch if I want it loadable by OpenCV dnn module 【发布时间】:2018-02-06 08:20:30 【问题描述】:

我通过 PyTorch 训练了一个简单的分类模型并通过 opencv3.3 加载它,但是它抛出异常并说

OpenCV 错误:函数/特性未在 readObject、文件中实现(不支持的 Lua 类型) /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp, 第 797 行 /home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp:797: 错误:(-213) 函数 readObject 中不支持的 Lua 类型

模型定义

class conv_block(nn.Module):
    def __init__(self, in_filter, out_filter, kernel):
        super(conv_block, self).__init__()

        self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
        self.batchnorm = nn.BatchNorm2d(out_filter)
        self.maxpool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = F.relu(x)
        x = self.maxpool(x)

        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = conv_block(3, 6, 3)
        self.conv2 = conv_block(6, 16, 3)
        self.fc1 = nn.Linear(16 * 8 * 8, 120)
        self.bn1 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn2 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

此模型仅使用 Conv2d、ReLU、BatchNorm2d、MaxPool2d 和 Linear 层,opencv3.3 支持每一层

我用 state_dict 保存

torch.save(net.state_dict(), 'cifar10_model')

通过c++加载为

std::string const model_file("/home/some_folder/cifar10_model");

std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);

我想我用错误的方式保存模型,保存 PyTorch 模型以便使用 OpenCV 加载的正确方法是什么?谢谢

编辑:

我用另一种方式保存模型,但也无法加载

torch.save(net, 'cifar10_model.net')

这是一个错误吗?还是我做错了什么?

【问题讨论】:

【参考方案1】:

我找到了答案,opencv3.3 不支持 PyTorch (https://github.com/pytorch/pytorch) 但 pytorch (https://github.com/hughperkins/pytorch),这是一个很大的惊喜,我不知道还有另一个版本的 pytorch 存在(看起来像一个死项目,好久没更新了),希望能在wiki上提一下他们支持的pytorch。

【讨论】:

有趣,支持早期 pytorch 的 opencv 对我来说也是新信息 :)。 (我写了较早的pytorch,在正式创建之前) 我觉得可行的办法是将pytorch模型转成caffe,在opencv中读取

以上是关于如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它的主要内容,如果未能解决你的问题,请参考以下文章

OpenCV学习笔记 - DNN模块使用(含源码详细解释)

OpenCV DNN模块——从TensorFlow模型导出到OpenCV部署详解

CUDA(GPU) 作为 OpenCV 后端

哪个caffe版本opencv在dnn模块中使用?

OpenCV DNN 模块-风格迁移

如何从 .pb 文件为 opencv 中的 dnn 模块生成 .pbtxt 文件?