Tensorflow细节-P112-模型持久化

Posted liuboblog

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow细节-P112-模型持久化相关的知识,希望对你有一定的参考价值。

第一个代码

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

看看看,就是上面:注意两个方面
(1)saver = tf.train.Saver()提前设定好
(2)saver.save(sess, "Saved_model/model.ckpt")这里面有sess要注意!

第二个代码

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)

这里有三个要注意的点
(1)上面定义好了模型(变量名字与第一个代码一样),Saver()里什么都没有
(2)saver.restore(sess, "Saved_model/model.ckpt")里有sess,ckpt是数据
(3)result是读取数据的结果,跟这里的变量没关系

第三个代码

import tensorflow as tf
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1) 
    print sess.run(v2) 
    print sess.run(v3) 

看这里,由于v3是一个变量,要输出的话需要先进行初始化(v1、v2不用)

下面,就是滑动平均了

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
        print(variables.name)
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0 ?v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))

技术图片
从上面的代码和图片可以看到开始时是一个变量,后来经过maintain_averages_op = ema.apply(tf.global_variables())就多了一个影子变量,这样子,就把影子变量存好了

下面就是加载滑动平均的影子变量了

v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)

注意重命名

以上是关于Tensorflow细节-P112-模型持久化的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow细节-P190-输入文件队列

Tensorflow细节-P42张量的概念及使用

Tensorflow细节-P89-collection的使用

TensorFlow模型持久化

Tensorflow细节-P160-迁移学习

TensorFlow学习笔记--网络模型的保存和读取