ONNX基本操作
Posted 洪流之源
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ONNX基本操作相关的知识,希望对你有一定的参考价值。
1. Pytorch导出ONNX
torch.onnx.export函数实现了pytorch模型到onnx模型的导出,在pytorch1.11.0中,torch.onnx.export函数参数如下:
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
enable_onnx_checker=True, use_external_data_format=False):
参数比较多,但常用的有如下几个:
model: pytorch模型
args: 第一个参数model的输入数据,因为模型的输入可能不止一个,因此采用元组作为参数
f: 导出的onnx模型文件路径
export_params: 导出的onnx模型文件可以包含网络结构与权重参数,如果设置该参数为False,则导出的onnx模型文件只包含网络结构,因此,一般保持默认为True即可
verbose: 该参数如果指定为True,则在导出onnx的过程中会打印详细的导出过程信息
input_names: 为输入节点指定名称,因为输入节点可能多个,因此该参数是一个列表
output_names: 为输出节点指定名称,因为输出节点可能多个,因此该参数是一个列表
opset_version: 导出onnx时参考的onnx算子集版本
dynamic_axes: 指定输入输出的张量,哪些维度是动态的,通过用字典的形式进行指定,如果某个张量的某个维度被指定为字符串或者-1,则认为该张量的该维度是动态的,但是一般建议只对batch维度指定动态,这样可提高性能,具体的格式见下面的代码
如下代码,定义了一个包含卷积层、relu激活层的网络,将该网络导出onnx模型,设置了输入、输出的batch、height、width3个维度是动态的
import torch
import torch.nn as nn
import torch.onnx
import os
# 定义一个模型
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 3, padding=1)
self.relu = nn.ReLU()
self.conv.weight.data.fill_(1) # 权重被初始化为1
self.conv.bias.data.fill_(0) # 偏置被初始化为0
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
return x
model = Model()
dummy = torch.zeros(1, 1, 3, 3)
torch.onnx.export(
model,
# 输入给model的数据,因为是元组类型,因此用括号
(dummy,),
# 导出的onnx文件路径
"demo.onnx",
# 打印导出过程详细信息
verbose=True,
# 为输入和输出节点指定名称,方便后面查看或者操作
input_names=["image"],
output_names=["output"],
# 导出时参考的onnx算子集版本
opset_version=11,
# 设置batch、height、width3个维度是动态的,
# 在onnx中会将其维度赋值为-1,
# 通常,我们只设置batch为动态,其它的避免动态
dynamic_axes=
"image": 0: "batch", 2: "height", 3: "width",
"output": 0: "batch", 2: "height", 3: "width",
)
print("Done.!")
2. netron可视化
netron可视化可以看到网络输入层为image,输出层为output,这些层名都是在onnx导出时指定的,另外红色框标注处,显示batch、height、width三个维度为动态的。
以上是关于ONNX基本操作的主要内容,如果未能解决你的问题,请参考以下文章
大数据必学Java基础(七十):不要用字符流去操作非文本文件