TensorFlow学习:模型的保存与恢复(上)基本操作
Posted 谢小小XH
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TensorFlow学习:模型的保存与恢复(上)基本操作相关的知识,希望对你有一定的参考价值。
更新:
2018.5.4 补充模型保存和恢复的原理,补充了模型保存和恢复的一般流程
版本:tensorflow 1.8
前面一直说的都是没有涉及到模型的保存.一般深度学习的训练是很需要时间的,不可能程序退出了然后又重新训练一次,所以训练好的模型需要保存下来,方便之后的再训练或者是把模型分享给别人都是可以的.模型的保存也可以叫做持久化,一个意思.接下来不啰嗦了,用一个简单的例子来说说模型怎么保存.
本节的所有代码可以在我的GitHub找到:
一.常见类和函数
在这部分先把模型保存和恢复中的常见类和函数列出来,可以暂时先不用详细看他们是怎么用的,这里先混个眼熟。
Ⅰ.tf.train.Saver
保存模型最基本的类就是这个类啦,所以一旦涉及到保存模型的需求,这个类是不可避免的。Saver
类提供了很多方便的操作能够从checkpoints(检查点)保存和恢复变量,
Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.
Saver类能够通过给定的计数器自动为checkpoint文件编号,这能够在训练模型的不同阶段保存多个检查点文件,举个栗子,你能够使用训练轮数来为你的检查点文件编号,为了避免占满硬盘,Saver还能够自动管理检查点文件,比如保留最新的那N个文件等等。
Saver类提供了一些函数来进行模型的保存和恢复,这里按照平时使用的频率来排序,列出常见的类方法。
save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix=’meta’,write_meta_graph=True,write_state=True,strip_default_attrs=False)
作用就是保存变量,
参数:
sess: 运行当前图的session
save_path: 检查点(checkpoint)文件名
global_step: If provided the global step number is appended to save_path to create the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
latest_filename: 可选名称,表示最新的检查点(checkpoints),默认为’checkpoint’.
meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to ‘meta’.
write_meta_graph: Boolean indicating whether or not to write the meta graph file.
write_state: Boolean indicating whether or not to write the CheckpointStateProto.
strip_default_attrs: Boolean. If True, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see Stripping Default-Valued Attributes.
返回:
检查点文件保存的路径。 If the saver is sharded, this string ends with: ‘-?????-of-nnnnn’ where ‘nnnnn’ is the number of shards created. If the saver is empty, returns None.
restore(sess,save_path)
恢复变量,同时要求图运行在这个session里面。
参数:
sess: 用来恢复参数的session
save_path: 保存模型的地址,一般来说,常常使用save()
函数返回的地址后者使用latest_checkpoint()
来得到地址。
Ⅱ.tf.train.latest_checkpoint
tf.train.latest_checkpoint(checkpoint_dir,latest_filename=None)
找到最新保存的checkpoint文件的文件名
参数:
checkpoint_dir: 变量保存的目录
latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames. See the corresponding argument to Saver.save().
Returns:
The full path to the latest checkpoint or None if no checkpoint was found.
Ⅲ.tf.train.export_meta_graph
tf.train.export_meta_graph(filename=None,meta_info_def=None,graph_def=None,saver_def=None,collection_list=None,as_text=False,graph=None,export_scope=None,clear_devices=False,clear_extraneous_savers=False,strip_default_attrs=False,**kwargs)
Returns MetaGraphDef proto. Optionally writes it to filename.
This function exports the graph, saver, and collection objects into MetaGraphDef protocol buffer with the intention of it being imported at a later time or location to restart training, run inference, or be a subgraph.
Args:
filename: Optional filename including the path for writing the generated MetaGraphDef protocol buffer.
meta_info_def: MetaInfoDef protocol buffer.
graph_def: GraphDef protocol buffer.
saver_def: SaverDef protocol buffer.
collection_list: List of string keys to collect.
as_text: If True, writes the MetaGraphDef as an ASCII proto.
graph: The Graph to export. If None, use the default graph.
export_scope: Optional string. Name scope under which to extract the subgraph. The scope name will be striped from the node definitions for easy import later into new name scopes. If None, the whole graph is exported. graph_def and export_scope cannot both be specified.
clear_devices: Whether or not to clear the device field for an Operation or Tensor during export.
clear_extraneous_savers: Remove any Saver-related information from the graph (both Save/Restore ops and SaverDefs) that are not associated with the provided SaverDef.
strip_default_attrs: Boolean. If True, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see Stripping Default-Valued Attributes.
**kwargs: Optional keyed arguments.
Returns:
A MetaGraphDef proto.
Raises:
ValueError: When the GraphDef is larger than 2GB.
RuntimeError: If called with eager execution enabled.
Eager Compatibility
Exporting/importing meta graphs is not supported. No graph exists when eager execution is enabled.
Ⅳ.tf.train.import_meta_graph
tf.train.import_meta_graph(meta_graph_or_file,clear_devices=False,import_scope=None,**kwargs)
作用是把MetaGraphDef proto
中存储的图重新创建出来。这个函数使用MetaGraphDef protocol buffer
作为输入,如果当前参数是一个包含 MetaGraphDef protocol buffer
的文件, 那么它会从文件内容中构建一个protocol buffer
然后把graph_def
域中所有的结点添加到当前图,重新构建所有的集合,同时返回一个从 saver_def
中构建的saver对象。
一般这个函数结合export_meta_graph()
函数来使用。用作对于以保存的图的恢复。
这里举个栗子
# Create a saver.
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
# Saves checkpoint, which by default also exports a meta_graph
# named 'my-model-global_step.meta'.
saver.save(sess, 'my-model', global_step=step)
Later we can continue training from this saved meta_graph without building the model from scratch.
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
NOTE: Restarting training from saved meta_graph only works if the device assignments have not changed.
参数:
meta_graph_or_file:
MetaGraphDef protocol buffer
或者文件名,简单起见,我们平时其实就是使用保存模型时候得到的model.ckpt.meta
文件。
clear_devices: Whether or not to clear the device field for an Operation or Tensor during import.
import_scope: Optional string. Name scope to add. Only used when initializing from protocol buffer.
**kwargs: Optional keyed arguments.
返回:
saver对象,要是没有任何变量被恢复,那么会返回一个None
Ⅴ.get_tensor_by_name
get_tensor_by_name(name)
作用是用给定的名字返回对应的Tensor,(这个函数能够同时被多个线程调用)
参数:
name: 指定的名称
Ⅵ
Ⅶ
二.简单示例
这里可以先暂时不用知道原理,只知道这么做的过程就行,在后面会详细的分析这个例子的原理.
Ⅰ.保存
import tensorflow as tf
import numpy as np
#graph
graph=tf.Graph()
with graph.as_default():
a=tf.Variable(initial_value=[[1,2],[3,4]],dtype=tf.float32,name="a")
b = tf.Variable(initial_value=[[1, 1], [1, 1]], dtype=tf.float32, name="b")
c=a+b
cons=tf.constant(value=[1,2,3,4,5],name="cons")
init_op=tf.global_variables_initializer()
#Saver class
saver=tf.train.Saver()
with tf.Session(graph=graph) as sess:
sess.run(init_op)
print("c:\\n",sess.run(c))
print("cons:\\n",sess.run(cons))
#save model
path=saver.save(sess=sess,save_path="./model.ckpt")
print("path:",path)
运行结果:
c:
[[2. 3.]
[4. 5.]]
cons:
[1 2 3 4 5]
path: ./model.ckpt
同时相应的文件夹下面会生成几个模型文件。在我这个版本,这几个文件的名字分别为:checkpoint
,model.ckpt.data-00000-of-00001
, model.ckpt.index
,model.ckpt.meta
, 也就是说,虽然在save函数里面只是指定了路径为model.ckpt
,但是系统真正在保存的时候会在后面分别加上后缀保存为几个文件,这几个文件都是模型保存相关的文件,但是文件各自的作用是不同的。下面详细介绍一下每个文件的作用。
model.ckpt.meta:这个文件主要是保存计算图的结构,可以简单理解为网络的结构就行。
Ⅱ.恢复
import tensorflow as tf
import numpy as np
#graph
graph=tf.Graph()
with graph.as_default():
v1=tf.Variable(initial_value=tf.ones(shape=(2,2)),dtype=tf.float32,name="a")
v2 = tf.Variable(initial_value=tf.ones(shape=(2,2)), dtype=tf.float32, name="b")
v=v1+v2
cons=tf.constant(value=[2,3,4,5],name="cons")
init_op=tf.global_variables_initializer()
saver=tf.train.Saver()
with tf.Session(graph=graph) as sess:
#sess.run(init_op)
#restore model
saver.restore(sess=sess,save_path="./model.ckpt")
print("c:\\n",sess.run(v))
print("cons:\\n",sess.run(cons))
在恢复的代码中,图中的变量什么的都差不多(和保存模型来对比),但是这段代码中没有变量的初始化过程,这里需要注意的是,变量的值是通过已经保存的模型加载进来。变量名不需要一模一样,但是名字"name"
需要一样,变量的初始值形状一样就行. 也就是说,最后从保存的模型中恢复的时候,是按照name参数的名字来对应找的.
然而这样的方式是重复定义了计算图上面的基本运算。你必须定义和原来的一 样的代码才能够得到存储的东西,使用非常受限制,在一些简单的地方使用这个方式是很好的。
还有一种方法就是不重新定义图的运算,直接加载已经持久化的图。这种方法更加灵活,但是也有点小复杂.
Ⅲ.直接加载的方式
上面的恢复需要重复定义一些操作,这里也有直接把整个图加载的方法。就是使用tf.train.import_meta_graph
函数来载入
import numpy as np
import tensorflow as tf
#import meta graph
saver=tf.train.import_meta_graph(meta_graph_or_file="./model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess=sess,save_path="./model.ckpt")
#get default graph
graph=tf.get_default_graph()
print(graph)
#get tensor
tensor_a = graph.get_tensor_by_name(name="a:0")
print(tensor_a)
print(sess.run(tensor_a))
结果:
<tensorflow.python.framework.ops.Graph object at 0x000001A7E9B1C278>
Tensor("a:0", shape=(2, 2), dtype=float32_ref)
[[1. 2.]
[3. 4.]]
以上是关于TensorFlow学习:模型的保存与恢复(上)基本操作的主要内容,如果未能解决你的问题,请参考以下文章
机器学习与Tensorflow——tf.train.Saver()inception-v3的应用