8.1 PyTorch模型迁移
Posted 王小小小草
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了8.1 PyTorch模型迁移相关的知识,希望对你有一定的参考价值。
欢迎订阅本专栏:《PyTorch深度学习实践》
订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html
- 第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础。
- 第三章:学习PyTorch如何读入各种外部数据
- 第四章:利用PyTorch从头到尾创建、训练、评估一个模型,理解与熟悉PyTorch实现模型的每个步骤,用到的模块与方法。
- 第五章:学习如何利用PyTorch提供的3种方法去创建各种模型结构。
- 第六章:利用PyTorch实现简单与经典的模型全过程:简单二分类、手写字体识别、词向量的实现、自编码器实现。
- 第七章:利用PyTorch实现复杂模型:翻译机(nlp领域)、生成对抗网络(GAN)、强化学习(RL)、风格迁移(cv领域)。
- 第八章:PyTorch的其他高级用法:模型在不同框架之间的迁移、可视化、多个GPU并行计算。
深度学习的框架层出不穷,各有优势,为了解决各个框架之间模型可以迁移与共用,微软和facebook共同发布了ONNX, Open Neural Exchange,开放式神经网络交换。比如在PyTorch上训练得到的模型,在其他框架上是不可以直接使用的,有了ONNX,我们可以先把PyTorch训练好的模型用ONNX来导出并保存成ONNX模型,再用其他框架将ONNX模型读入。这就好像市场交易,我种的大米不能直接变成好看的衣服,但是可以先把大米买了换钱,用钱再去换衣服。ONNX就是充当了货币的角色。
本章分为2节,首先对ONNX做一个简单的介绍,然后使用ONNX将PyTorch的模型迁移到另一个很有名的框架Caffe2上。
8.1.1 ONNX简介
(1)ONNX支持的框架
目前ONNX支持的框架有:PyTorch, Caffe2, Microsoft Cognitive Tookkit, Chainer, MATLAB, SAS, MXNet, PaddlePaddle.
非官方支持的框架:Tensorflow, Keras, Core ML, scikit-learn, XGBoost, LibSVM, nxnn.
基本上涵盖了目前常用的框架,所以ONNX是非常实用的。
(2)ONNX模型动物园
所谓的模型动物园是指收集和存放了很多模型的地方,大家可以使用动物园里的模型,也可以将自己的模型放到动物园里,实现了模型的共享。
ONNX模型动物园收集了许多业界领先水平或者非常有名常被使用的深度学习模型。地址是:http://github.com/onnx/models, 这些模型都是经过与训练的,可以被直接下载并加载到自己使用的框架中。点开链接,在readme中可以看到各类模型的分类,点进分类中可以看到该类别下的具体模型。若要下载,直接点击模型名称即可。
(3)可视化ONNX模型
可以用Netron这个工具来可视化ONNX模型。
可以在地址https://www.lutzoeder/ai/ 来下载安装桌面版本,也可以直接登陆https://lutzoeder.github.io/netron/ 来访问网页版本。在网页上上传从onnx下载的模型,就能可视化这个模型的结构以及参数。
8.1.2 使用ONNX将PyTorch的模型迁移到Caffe2
(1)安装ONNX
若安装了PyTorch 1.0,程序已经安装好了Caffe2,现在来额外安装ONNX: pip install onnx
(2)ONNX导出PyTorch模型
import torch
import torch.onnx
import torchvision
torch_model = torchvision.models.alexnet(pretrained=True)
x = torch.randn(1,3,244,244)
torch_out = torch.onnx._export(torch_model, x, "model/alexnet.onnx",verbose=True)
(3)检验ONNX模型
import onnx
model = onnx.load("model/alexnet.onnx") # 加载模型
onnx.checker.check_model(model) # 检测
onnx.helper.printable_graph(model.graph) # 打印模型
(4)ONNX模型导入Caffe2
import numpy as np
import caffe2.python.onnx.backend as onnx_caffe2_backend
prepared_backend = onnx_caffe2_backend.prepare(model)
w = (model.graph.input[0].name: x.data.numpy())
c2_out = prepared_backend.run(w)[0]
以上是关于8.1 PyTorch模型迁移的主要内容,如果未能解决你的问题,请参考以下文章
在 PyTorch 中加载迁移学习模型进行推理的正确方法是啥?
从 `pytorch-pretrained-bert` 迁移到 `pytorch-transformers` 关于模型()输出的问题