onnx模型修改添加Node

Posted 洪流之源

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了onnx模型修改添加Node相关的知识,希望对你有一定的参考价值。

import onnx
import copy

# onnx 插入新的Node
def insert_node(model, insert_node, follow_up_node):
    # 根据插入Node的输出修改后续node的输入
    follow_up_node.input[0] = insert_node.output[0]
    # 找到后续Node的索引位置,并将插入节点插入到graph中
    for follow_up_node_index, _follow_up_node in enumerate(model.graph.node):
        if _follow_up_node == follow_up_node:
            print("follow_up_node_index: ", follow_up_node_index)
            model.graph.node.insert(follow_up_node_index, insert_node)
            break

if __name__ == '__main__':
    src_onnx_model_path = './models/onnx/my_lprnet_model.onnx'
    dst_onnx_model_path = './models/onnx/new.onnx'
    onnx_model = onnx.load(src_onnx_model_path)
    graph = onnx_model.graph
    node  = graph.node

    # 临时节点方便后续修改
    temp_node1 = None
    temp_node2 = None

    for i in range(len(node)):
        if node[i].op_type == 'Transpose':
            print(i, node[i].name)

    for i in range(len(node)):
        # 修改Transpose_16 维度参数
        if node[i].op_type == 'Transpose' and node[i].name == "Transpose_16":
            node[i].attribute[0].ints[1] = 2
            node[i].attribute[0].ints[2] = 3
            node[i].attribute[0].ints[3] = 1
            # 深拷贝Node,获得temp_node1,后续对temp_node1参数修改获得新的transpose_17_node
            temp_node1 = copy.deepcopy(node[i])
        # 修改Transpose_38 维度参数
        if node[i].op_type == 'Transpose' and node[i].name == "Transpose_38":
            node[i].attribute[0].ints[1] = 2
            node[i].attribute[0].ints[2] = 3
            node[i].attribute[0].ints[3] = 1
            # 深拷贝Node,获得temp_node2,后续对temp_node2参数修改获得新的Transpose_39_node
            temp_node2 = copy.deepcopy(node[i])

    # 修改temp_node1参数得到新的transpose_17_node
    transpose_17_node = temp_node1
    transpose_17_node.name = 'Transpose_17'
    transpose_17_node.input[0] = '79'
    transpose_17_node.output[0] = 'transpose_17_output'

    follow_up_node = None
    for i in range(len(node)):
        if node[i].op_type == 'Conv' and node[i].name == "Conv_17":
            follow_up_node = node[i]
            break

    insert_node(onnx_model, transpose_17_node, follow_up_node)

    # 修改temp_node2参数得到新的Transpose_39_node
    Transpose_39_node = temp_node2
    Transpose_39_node.name = 'Transpose_39'
    Transpose_39_node.input[0] = '101'
    Transpose_39_node.output[0] = 'transpose_39_output'

    follow_up_node = None
    for i in range(len(node)):
        if node[i].op_type == 'Conv' and node[i].name == "Conv_39":
            follow_up_node = node[i]
            break

    insert_node(onnx_model, Transpose_39_node, follow_up_node)
    
    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, dst_onnx_model_path)

 

以上是关于onnx模型修改添加Node的主要内容,如果未能解决你的问题,请参考以下文章

pytorch模型转换为rknn模型,使用npu推理

Pytorch的pth模型转onnx,再用ONNX Runtime调用推理(附python代码)

csharp通过onnx使用sklearn的模型

yolov5 pt 模型 导出 onnx

yolov5 pt 模型 导出 onnx

yolov5 pt 模型 导出 onnx