第五章 MNIST数字识别问题

Posted 山本夏木

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第五章 MNIST数字识别问题相关的知识,希望对你有一定的参考价值。

4.1. ckpt文件保存方法

在对模型进行加载时候,需要定义出与原来的计算图结构完全相同的计算图,然后才能进行加载,并且不需要对定义出来的计算图进行初始化操作。 
这样保存下来的模型,会在其文件夹下生成三个文件,分别是: 
* .ckpt.meta文件,保存tensorflow模型的计算图结构。 
* .ckpt文件,保存计算图下所有变量的取值。 
* checkpoint文件,保存目录下所有模型文件列表。

技术分享图片
import tensorflow as tf
#保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")
#加载保存了两个变量和的模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-1.6226364]
#直接加载持久化的图。因为之前没有导出v3,所以这里会报错
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1) 
    print sess.run(v2) 
    print sess.run(v3) 
INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-0.81131822]
[-0.81131822]

# 变量重命名,这样可以通过字典将模型保存时的变量名和需要加载的变量联系起来
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})
View Code

 

4.2.1 滑动平均类的保存

技术分享图片
import tensorflow as tf
#使用滑动平均
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables(): print variables.name
    
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables(): print variables.name
v:0
v:0
v/ExponentialMovingAverage:0

#保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print sess.run([v, ema.average(v)])
10.0, 0.099999905]

#加载滑动平均模型
v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999
View Code

 

4.2.2 variables_to_restore函数的使用样例

import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print ema.variables_to_restore()

#等同于saver = tf.train.Saver(ema.variables_to_restore())
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
{u‘v/ExponentialMovingAverage‘: <tf.Variable ‘v:0‘ shape=() dtype=float32_ref>}

 

4.3. pb文件保存方法

#pb文件的保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [‘add‘])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
------------------------------------------------------------------------
#加载pb文件
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
   
    with gfile.FastGFile(model_filename, ‘rb‘) as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print sess.run(result)

[array([ 3.], dtype=float32)]

张量的名称后面有:0,表示是某个计算节点的第一个输出,而计算节点本身的名称后是没有:0的。





以上是关于第五章 MNIST数字识别问题的主要内容,如果未能解决你的问题,请参考以下文章

手写数字识别——基于全连接层和MNIST数据集

pytorch学习实战第五篇:卷积神经网络实现MNIST手写数字识别

《Python深度学习》第五章-1(CNN简介)读书笔记

神经网络和深度学习笔记 - 第五章 深度神经网络学习过程中的梯度消失问题

人工智能实践:tensorflow笔记-第五周

基于WebGL实现矩阵计算