tensorflow模型导出

Posted zbxzc

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow模型导出相关的知识,希望对你有一定的参考价值。


tf.train.Saver类负责保存和还原神经网络
自动保存为三个文件:模型文件列表checkpoint,计算图结构model.ckpt.meta,每个变量的取值model.ckpt

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构 
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表






Tensorflow : What is the relationship between .ckpt file and .ckpt.meta and .ckpt.index , and .pb file

tensorflow模型保存文件分析

Tensorflow 模型持久化

查看tensorflow ckpt文件中的变量名和对应值

TensorFlow模型保存和提取方法

TensorFlow模型文件保存和读取

tensorflow保存部分变量

python下tensorflow模型的导出

tensorflow 中导出/恢复模型Graph数据Saver

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md

https://stackoverflow.com/questions/35508866/tensorflow-different-ways-to-export-and-run-graph-in-c/43639305#43639305

https://github.com/anandanand84/tensorflow_model_server

https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow/38853802#

export



『TensorFlow』模型载入方法汇总


import tensorflow as tf
from tensorflow.python.platform import gfile

# 这是从二进制格式的pb文件加载模型
graph = tf.get_default_graph()
graphdef = graph.as_graph_def()
graphdef.ParseFromString(gfile.FastGFile("/data/TensorFlowandroidMNIST/app/src/main/expert-graph.pb", "rb").read())
_ = tf.import_graph_def(graphdef, name="")

#这是从meta文件加载模型
_ = tf.train.import_meta_graph("model.ckpt.meta")

summary_write = tf.summary.FileWriter("/data/TensorFlowAndroidMNIST/logdir" , graph)




freeze_graph

tensorflow,使用freeze_graph.py将模型文件和权重数据整合在一起并去除无关的Op

tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测

算法移植优化(六)tensorflow模型移植推理优化

Tensorflow 训练模型数据freeze固话保存在Graph中




将TensorFlow的网络导出为单个文件

将tensorflow训练好的模型移植到android

Steps to reproduce freeze_graph



tf.train.write_graph
tf.train.write_graph()保存模型,它只是保存了模型的结构,并不保存训练完毕的参数值


tf.train.saver()保存模型,将网络中的参数值与模型的结构分开存储
tf.train.Saver函数保存模型文件的时候,是保存所有的参数信息,而有些时候我们并不需要所有的参数信息。我们只需要知道神经网络的输入层经过前向传播计算得到输出层即可,所以在保存的时候,我们也不需要保存所有的参数,以及变量的初始化、模型保存等辅助节点信息与迁移学习类似。之前使用tf.train.Saver函数保存模型文件的时候会产生多个文件,它将变量的取值和计算图结构分成了不同的文件存储。TensorFlow提供了另一种保存模型文件的方法,将计算图保存在一个文件中


写入pb文件

TensorFlow的convert_variables_to_constants函数

如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件


graph_util.convert_variables_to_constants可以把整个sesion当作常量都保存下来,通过output_node_names参数来指定输出

tf.gfile.FastGFile('model/cxq.pb', mode='wb')指定保存文件的路径以及读写方式

f.write(output_graph_def.SerializeToString())将固化的模型写入到文件



模型保存

import tensorflow as tf  
from tensorflow.python.framework import graph_util  
from tensorflow.python.platform import gfile  
  
if __name__ == "__main__":  
    a = tf.Variable(tf.constant(5.,shape=[1]),name="a")  
    b = tf.Variable(tf.constant(6.,shape=[1]),name="b")  
    c = a + b  
    init = tf.initialize_all_variables()  
    sess = tf.Session()  
    sess.run(init)  
    #导出当前计算图的GraphDef部分  
    graph_def = tf.get_default_graph().as_graph_def()  
    #保存指定的节点,并将节点值保存为常数  
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])  
    #将计算图写入到模型文件中  
    model_f = tf.gfile.GFile("model.pb","wb")  
    model_f.write(output_graph_def.SerializeToString()) 

convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。在保存的时候,通过convert_variables_to_constants函数来指定保存的节点名称而不是张量的名称,“add:0”是张量的名称而"add"表示的是节点的名称


模型读取

sess = tf.Session()  
   #将保存的模型文件解析为GraphDef  
   model_f = gfile.FastGFile("model.pb",'rb')  
   graph_def = tf.GraphDef()  
   graph_def.ParseFromString(model_f.read())  
   c = tf.import_graph_def(graph_def,return_elements=["add:0"])  
   print(sess.run(c))  
   #[array([ 11.], dtype=float32)]

在读取模型文件获取变量的值的时候,我们需要指定的是张量的名称而不是节点的名称




以上是关于tensorflow模型导出的主要内容,如果未能解决你的问题,请参考以下文章

模型导出与部署TensorFlow Client对接模型服务

从 Tensorflow 对象检测 API 动物园模型导出错误的冻结图

tensorflow模型导出

tensorflow Mobilenet 导出模型的方法

TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

尝试将我的keras模型导出到tensorflow服务时出错