onnx模型修改添加Node
Posted 洪流之源
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了onnx模型修改添加Node相关的知识,希望对你有一定的参考价值。
import onnx
import copy
# onnx 插入新的Node
def insert_node(model, insert_node, follow_up_node):
# 根据插入Node的输出修改后续node的输入
follow_up_node.input[0] = insert_node.output[0]
# 找到后续Node的索引位置,并将插入节点插入到graph中
for follow_up_node_index, _follow_up_node in enumerate(model.graph.node):
if _follow_up_node == follow_up_node:
print("follow_up_node_index: ", follow_up_node_index)
model.graph.node.insert(follow_up_node_index, insert_node)
break
if __name__ == '__main__':
src_onnx_model_path = './models/onnx/my_lprnet_model.onnx'
dst_onnx_model_path = './models/onnx/new.onnx'
onnx_model = onnx.load(src_onnx_model_path)
graph = onnx_model.graph
node = graph.node
# 临时节点方便后续修改
temp_node1 = None
temp_node2 = None
for i in range(len(node)):
if node[i].op_type == 'Transpose':
print(i, node[i].name)
for i in range(len(node)):
# 修改Transpose_16 维度参数
if node[i].op_type == 'Transpose' and node[i].name == "Transpose_16":
node[i].attribute[0].ints[1] = 2
node[i].attribute[0].ints[2] = 3
node[i].attribute[0].ints[3] = 1
# 深拷贝Node,获得temp_node1,后续对temp_node1参数修改获得新的transpose_17_node
temp_node1 = copy.deepcopy(node[i])
# 修改Transpose_38 维度参数
if node[i].op_type == 'Transpose' and node[i].name == "Transpose_38":
node[i].attribute[0].ints[1] = 2
node[i].attribute[0].ints[2] = 3
node[i].attribute[0].ints[3] = 1
# 深拷贝Node,获得temp_node2,后续对temp_node2参数修改获得新的Transpose_39_node
temp_node2 = copy.deepcopy(node[i])
# 修改temp_node1参数得到新的transpose_17_node
transpose_17_node = temp_node1
transpose_17_node.name = 'Transpose_17'
transpose_17_node.input[0] = '79'
transpose_17_node.output[0] = 'transpose_17_output'
follow_up_node = None
for i in range(len(node)):
if node[i].op_type == 'Conv' and node[i].name == "Conv_17":
follow_up_node = node[i]
break
insert_node(onnx_model, transpose_17_node, follow_up_node)
# 修改temp_node2参数得到新的Transpose_39_node
Transpose_39_node = temp_node2
Transpose_39_node.name = 'Transpose_39'
Transpose_39_node.input[0] = '101'
Transpose_39_node.output[0] = 'transpose_39_output'
follow_up_node = None
for i in range(len(node)):
if node[i].op_type == 'Conv' and node[i].name == "Conv_39":
follow_up_node = node[i]
break
insert_node(onnx_model, Transpose_39_node, follow_up_node)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, dst_onnx_model_path)
以上是关于onnx模型修改添加Node的主要内容,如果未能解决你的问题,请参考以下文章