Tensorflow saver.restore()不恢复网络

Posted

技术标签:

【中文标题】Tensorflow saver.restore()不恢复网络【英文标题】:Tensorflow saver.restore() not restoring network 【发布时间】:2017-11-06 18:45:06 【问题描述】:

我完全迷失了 tensorflow saver 方法。

我正在尝试学习基本的 tensorflow 深度神经网络模型教程。我想弄清楚如何训练网络进行几次迭代,然后在另一个会话中加载模型。

with tf.Session() as sess:
    graph = tf.Graph()
    x = tf.placeholder(tf.float32,shape=[None,784])
    y_ = tf.placeholder(tf.float32, shape=[None,10])

    sess.run(global_variables_initializer())

    #Define the Network
    #(This part is all copied from the tutorial - not copied for brevity)
    #See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/

跳过训练。

    #Train the Network
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
                     cross_entropy,global_step=global_step)
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    saver = tf.train.Saver()

    for i in range(101):
        batch = mnist.train.next_batch(50)
        if i%100 == 0:
        train_accuracy = accuracy.eval(feed_dict=
                           x:batch[0],y_:batch[1])
        print 'Step %d, training accuracy %g'%(i,train_accuracy)
            train_step.run(feed_dict=x:batch[0], y_: batch[1])
        if i%100 == 0:
            print 'Test accuracy %g'%accuracy.eval(feed_dict=x: 
                       mnist.test.images, y_: mnist.test.labels)

        saver.save(sess,'./mnist_model')

控制台打印出来:

第 0 步,训练精度 0.16

测试精度 0.0719

Step 100,训练精度 0.88

测试精度 0.8734

接下来我要加载模型

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('mnist_model.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    sess.run(tf.global_variables_initializer())

现在我想重新测试一下是否加载了模型

print 'Test accuracy %g'%accuracy.eval(feed_dict=x: 
                       mnist.test.images, y_: mnist.test.labels)

控制台打印出来:

测试精度 0.1151

模型似乎没有保存任何数据?我做错了什么?

【问题讨论】:

你不应该在恢复权重后运行sess.run(tf.global_variables_initializer())。这将重置您的所有权重 【参考方案1】:

当您保存模型时,通常所有全局变量都保存在外部文件中,而局部变量则不是。你可以看看这个answer 来了解其中的区别。

恢复代码中的错误是调用tf.global_variable_initializer() 之后 saver.restore()saver.restore 文档提到,

要恢复的变量不必已经初始化,因为恢复本身就是一种初始化变量的方法。

因此,请尝试删除该行,

sess.run(tf.global_variables_initializer())

理想情况下,您应该将其替换为,

sess.run(tf.local_variables_initializer())

【讨论】:

谢谢,这似乎确实解决了我的问题!如果文件声明saver.restore() 是一个初始化过程,那么sess.run(tf.local_variables_initializer()) 是否有任何用途?这似乎也表明 A quick complete tutorial to save and restore Tensorflow models 之类的教程显示不正确的用法,不是吗? 你应该检查tf.local_variables()。如果此列表非空,则需要它

以上是关于Tensorflow saver.restore()不恢复网络的主要内容,如果未能解决你的问题,请参考以下文章

求助 tensorflow怎样恢复预训练的模型啊

在张量流检查点中修改张量的形状

深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

TensorFlow报错:ValueError The passed save_path is not a valid checkpoint

TensorFlow报错:ValueError The passed save_path is not a valid checkpoint

人工智能实践:全连接网络实践