如何使用队列方法(没有 feed_dict)#tensorflow 在保存的模型上使用测试数据?
Posted
技术标签:
【中文标题】如何使用队列方法(没有 feed_dict)#tensorflow 在保存的模型上使用测试数据?【英文标题】:How to use Test data on saved model with queue approach (without feed_dict) #tensorflow? 【发布时间】:2017-11-13 16:44:43 【问题描述】:我是张量流的新手。我已经为mnist图像分类构建了一个convonet,如下所示我正在使用队列从磁盘批处理中读取图像(png)并将其传递给训练操作(我现在对此很满意)在训练之前一切都很好,我正在评估我在训练时以一定数量的步数运行准确度。
我正在使用 Saver 对象保存模型,并且可以看到正在写入磁盘的元和检查点文件。
现在真正的挑战是在模型完成训练后恢复模型并将其用于预测新图像
我的图表中的第一步(训练)如下所示,它采用 x_image(来自训练队列的图像)h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
由于我没有使用提要字典方法,我不能只使用保护程序恢复准确性并传递新数据。我必须为测试数据定义队列并重建图形(与之前完全相同),参考 x_image 更改为指向测试数据队列。
我现在如何在训练时恢复学习到的权重,并将其与这个新图表一起使用,以简单地运行我的预测/准确度操作。
我试图跟随 - https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py 教程,但被 eval 代码迷路了。
此外,如果我在我的训练图中添加一个虚拟常量,然后尝试检索它的值,我就可以检索它。
请任何 1 帮忙。谢谢
【问题讨论】:
我可以使用 saver.restore() 并恢复图形的变量。小心我没有运行 tf.global_variables_initializer() 以便变量/权重不会重新初始化,而是从保存的模型中恢复。我现在观察到的唯一奇怪的事情是我的预测操作为同一输入图像返回不同的标签。我正在使用 tf.train.shuffle_batch() 生成测试样本。任何人都可以指出我的错误。谢谢 【参考方案1】:好的,所以我找到了答案。
最初的挑战是在训练和验证阶段使用队列时在训练和测试数据之间切换。 现在由于队列是图结构的一部分,我们不能简单地修改它们。
我发现一篇文章使用 tf.case 在训练队列和测试队列之间切换,但我无法同时使用 shuffle batch。
手头的真正任务是在训练后保存模型,并在生产中使用保存的模型进行预测。
所以这是流程:
培训
创建一个创建图形的方法(将图像张量作为 输入)。 通过传递训练图像批次构建训练图 执行训练并使用保护对象保存模型。评估
现在使用测试图像批次重建相同的图表。 在会话中使用 saver 对象恢复权重(注意您不需要传递要恢复的变量,默认情况下它会恢复所有可恢复的变量) 此时不要运行 gloabl 变量初始化程序 运行您的预测操作(从新构建的图表生成)还要确保关闭评估中的退出功能,因为它会不断改变相同输入的输出
下面是伪代码
train_op, y_predict, accuracy = create_graph(train_input, train_label)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
model_saver = tf.train.Saver()
for i in range(2000):
if i%100 == 0:
train_accuracy = sess.run(accuracy)
print("step %d, training accuracy %f" %(i, train_accuracy))
sess.run(train_op)
print(sess.run(accuracy))
model_saver.save(sess, 'model/simple_model', global_step=100)
coord.request_stop()
coord.join(threads)
用于评估
_, y_predict, accuracy = create_graph(test_input, test_label)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint("./model/"))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
label_predict = sess.run([y_predict])
【讨论】:
以上是关于如何使用队列方法(没有 feed_dict)#tensorflow 在保存的模型上使用测试数据?的主要内容,如果未能解决你的问题,请参考以下文章