TensorFlow 存储与读取

Posted 郭老猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TensorFlow 存储与读取相关的知识,希望对你有一定的参考价值。

之前通过CNN进行的MNIST训练识别成功率已经很高了,不过每次运行都需要消耗很多的时间。在实际使用的时候,每次都要选经过训练后在进行识别那就太不方便了。

所以我们学习一下如何将训练习得的参数保存起来,然后在需要用的时候直接使用这些参数进行快速的识别。

本章节代码来自《Tensorflow 实战Google深度学习框架》5.5 TensorFlow 最佳实践样例程序  针对书中的代码做了一点点的调整。

 

mnist_inference.py:

#coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection(losses, regularizer(weights))
    return weights

def inference(input_tensor, regularizer):
    with tf.variable_scope(layer1):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope(layer2):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

这里是向前传播的方法文件。这个方法在训练和测试的过程都需要用到,将它抽离出来既能使用起来更加方便,也能保证训练和测试时使用的方法保持一致。

get_variable

 weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))

源代码第十行使用get_variable函数获取变量。

在训练网络是会创建这些变量;

在测试时会通过训练时保存的模型加载这些变量的值。

 

(未完待续。。。。)

以上是关于TensorFlow 存储与读取的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow学习教程------tfrecords数据格式生成与读取

tensorflow学习笔记三:实例数据下载与读取

tensorflow的tfrecord操作代码与数据协议规范

十图详解TensorFlow数据读取机制(附代码)

TensorFlow学习笔记读取数据

TensorFlow 读取带有标签的图像