从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)
Posted 走召大爷
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)相关的知识,希望对你有一定的参考价值。
最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。
Tensorflow官方提供的Tensorboard可以可视化神经网络结构图,但是说实话,我几乎从来不用。主要是因为Tensorboard中查看到的图结构太混乱了,包含了网络中所有的计算节点(读取数据节点、网络节点、loss计算节点等等)。更可怕的是,如果一个计算节点是由多个基础计算(如加减乘除等)构成,那么在Tensorboard中会将基础计算节点显示而不是作为一个整体显示(典型的如Squeeze计算节点)。最近为了排查网络结构BUG花费一周时间,因此,狠下心来决定自己写一个工具,将Tensorflow中的图以最简单的方式显示最关键的网络结构。
1 Tensor对象与Operation对象
Tensorflow中,Tensor对象主要用于存储数据如常量和变量(训练参数),Operation对象是计算节点,如卷积计算、反卷积计算、ReLU等等。每一个Operation对象均有输入和输出Tensor,同理,每个Tensor对象均有对应生成该Tensor的Operation对象和使用该Tensor对象作为输入的Operation对象。Tensor和Operation对象内均有相关属性和函数来获取其关联的Operation和Tensor对象,相关属性如下所示。
- Tensor对象的op属性指向生成该Tensor的Operation对象。
- Tensor对象的consumers()函数获取使用该Tensor对象作为输入的Operation对象。
- Operation对象的inputs属性指向该计算节点的输入Tensor对象。
- Operation对象的outputs属性执行该计算节点的输出Tensor对象。
如下图所示的网络结构中,调用Tensor_2
对象的consumers()
函数,返回的是[op_1,op_2]
。Tensor_3
的op属性指向的是op_1
。op_1
的inputs属性指向的是[Tensor_1,Tensor_2]
,op_1
的output属性指向的是[Tensor_3]
。
有了Tensor与Operation对应在图中的关联关系,就可以将网络结构给画出来。
2 提取pb文件中的网络结构图
pb文件是将模型参数固化到图文件中,并合并了一些基础计算和删除了反向传播相关计算得到的protobuf协议文件。如果读者还不懂如何将CKPT模型文件转pb文件,请参考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1节部分。有了pb模型文件后,接下来是加载模型,加载pb模型示例代码如下所示。
def read_graph_from_pb(tf_model_path ,input_names,output_name):
with open(tf_model_path, 'rb') as f:
serialized = f.read()
tf.reset_default_graph()
gdef = tf.GraphDef()
gdef.ParseFromString(serialized)
with tf.Graph().as_default() as g:
tf.import_graph_def(gdef, name='')
with tf.Session(graph=g) as sess:
OPS=get_ops_from_pb(g,input_names,output_name)
return OPS
其中,倒数第2行调用到的函数get_ops_from_pb()
用于获取网络结构图中指定输入节点和指定输出节点之间的计算节点。之所以要指定输入和输出,是为了将输入之前的计算节点(如加载数据队列等相关计算节点)和输出之后的计算节点(如计算loss等相关计算节点)去除,免得碍眼。函数get_ops_from_pb()
实现代码如下。
def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
if save_ori_network:
with open('ori_network.txt','w+') as w:
OPS=graph.get_operations()
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\\n')
inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
output_tf =graph.get_tensor_by_name(output_name)
OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] )
with open('network.txt','w+') as w:
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\\n')
OPS = sort_ops(OPS)
OPS = merge_layers(OPS)
return OPS
在裁剪网络结构(即只保留input_names和output_name之间节点)之前,先将原始的网络结构写入到ori_network.txt
中,文件中,每一行写入:输入Tensor---->op---->输出Tensor
。接下来调用函数get_ops_from_inputs_outputs
获取指定节点之间的节点。并调用sort_ops
函数对所有的节点排序,以保证被依赖的节点总是出现在相关节点之前。最后调用merge_layers
函数,将一些可以合并的计算合并成一个独立的节点,例如,Squeeze
计算相关节点合并成一个单独的Squeeze节点,又如const-->identity
两个计算节点可以直接忽略(即删除)。
注意:篇幅有限,这里不再将函数
get_ops_from_inputs_outputs
、sort_ops
、merge_layers
贴出,相关代码请前往文尾提供的源码地址中阅读。
3 绘制网络结构
考虑到SVG
绘制图形的简单易用优点,将排好序的网络计算节点和相关Tensor
对象数据以javascript
字符串的形式写入到html
中,使用<line>
标签绘制箭头,使用<rect>
标签绘制矩形,使用<ellipse>
标签绘制椭圆,使用<text>
标签显示文字。绘制类似于如下所示图像
注意:篇幅有限,这里不再介绍Javascript代码解析模型结构和SVG显示相关的原理,相关代码请前往文尾提供的源码地址中阅读。
4 测试模型显示
以《MobileNet V1官方预训练模型的使用》文中介绍的MobileNet V1网络结构为例,下载MobileNet_v1_1.0_192
文件并压缩后,得到mobilenet_v1_1.0_192_frozen.pb
文件。我们还需要知道mobilenet_v1_1.0_192_frozen.pb
模型对应的输入和输出Tensor
对象的名称,好在MobileNet_v1_1.0_192
压缩包中包含文件mobilenet_v1_1.0_192_info.txt
。通过该文件可知,输入Tensor
的名称为:input:0
,输出Tensor名称为:MobilenetV1/Predictions/Reshape_1:0
。有了这些信息后,调用函数read_graph_from_pb
得到静态图的节点列表对象ops,调用函数gen_graph(ops,"save/path/graph.html")
后,在目录save/path
中得到graph.html
文件,打开graph.html
后,显示结果如下。
显示网络结构分两种模式:合并模式和展开模式,分别如下图所示。
5 源码地址
https://github.com/huachao1001/CNNGraph
如果您觉得本文对您有帮助,或者想直接联系作者。欢迎关注我【Python学习实战】,每天学习一点点,每天进步一点点。
以上是关于从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)的主要内容,如果未能解决你的问题,请参考以下文章
代码解析深度学习系统编程模型:TensorFlow vs. CNTK
Google 第一个 TF 中文教学视频发布 | TensorFlow Lite 深度解析
如何从本地 JSON 文件解析数据并保存在模型类中并在 tableview 中使用