ONNX基本操作

Posted 洪流之源

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ONNX基本操作相关的知识,希望对你有一定的参考价值。

1. Pytorch导出ONNX

torch.onnx.export函数实现了pytorch模型到onnx模型的导出,在pytorch1.11.0中,torch.onnx.export函数参数如下:

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, 
           input_names=None, output_names=None, aten=False, export_raw_ir=False, 
           operator_export_type=None, opset_version=None, _retain_param_name=True, 
           do_constant_folding=True, example_outputs=None, strip_doc_string=True, 
           dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, 
           enable_onnx_checker=True, use_external_data_format=False):

参数比较多,但常用的有如下几个:

model: pytorch模型

args: 第一个参数model的输入数据,因为模型的输入可能不止一个,因此采用元组作为参数

f: 导出的onnx模型文件路径

export_params: 导出的onnx模型文件可以包含网络结构与权重参数,如果设置该参数为False,则导出的onnx模型文件只包含网络结构,因此,一般保持默认为True即可

verbose: 该参数如果指定为True,则在导出onnx的过程中会打印详细的导出过程信息

input_names: 为输入节点指定名称,因为输入节点可能多个,因此该参数是一个列表

output_names: 为输出节点指定名称,因为输出节点可能多个,因此该参数是一个列表

opset_version: 导出onnx时参考的onnx算子集版本

dynamic_axes: 指定输入输出的张量,哪些维度是动态的,通过用字典的形式进行指定,如果某个张量的某个维度被指定为字符串或者-1,则认为该张量的该维度是动态的,但是一般建议只对batch维度指定动态,这样可提高性能,具体的格式见下面的代码

如下代码,定义了一个包含卷积层、relu激活层的网络,将该网络导出onnx模型,设置了输入、输出的batch、height、width3个维度是动态的

import torch
import torch.nn as nn
import torch.onnx
import os

# 定义一个模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv.weight.data.fill_(1) # 权重被初始化为1
        self.conv.bias.data.fill_(0) # 偏置被初始化为0
    
    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


model = Model()
dummy = torch.zeros(1, 1, 3, 3)

torch.onnx.export(
    model, 

    # 输入给model的数据,因为是元组类型,因此用括号
    (dummy,), 

    # 导出的onnx文件路径
    "demo.onnx", 

    # 打印导出过程详细信息
    verbose=True, 

    # 为输入和输出节点指定名称,方便后面查看或者操作
    input_names=["image"], 
    output_names=["output"], 

    # 导出时参考的onnx算子集版本
    opset_version=11, 

    # 设置batch、height、width3个维度是动态的,
    # 在onnx中会将其维度赋值为-1,
    # 通常,我们只设置batch为动态,其它的避免动态
    dynamic_axes=
        "image": 0: "batch", 2: "height", 3: "width",
        "output": 0: "batch", 2: "height", 3: "width",
    
)

print("Done.!")

 2. netron可视化

 netron可视化可以看到网络输入层为image,输出层为output,这些层名都是在onnx导出时指定的,另外红色框标注处,显示batch、height、width三个维度为动态的。

以上是关于ONNX基本操作的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch导出ONNX踩坑指南

大数据必学Java基础(七十):不要用字符流去操作非文本文件

ONNX 和 TensorRT 模型中的参数数量和 FLOPS

ONNX 开始

由微软打造的深度学习开放联盟ONNX成立

算法工具-1.torch Pt模型转onnx(torch.onnx.export(m, d, onnx_path))