关闭会话后运行 TensorFlow 模型测试数据
Posted
技术标签:
【中文标题】关闭会话后运行 TensorFlow 模型测试数据【英文标题】:Running Tensorflow model test data after closing session 【发布时间】:2019-01-06 06:23:23 【问题描述】:我有一个我正在尝试复制的 Convnet(不是我的原始代码),只有当我在同一个座位上进行训练和测试时,它才能将测试数据集运行到经过训练的模型中。我只调整了几行代码,让它在说坐之后运行测试数据,所以我不确定会发生什么。我注意到“logits_out”是数据流边缘而不是张量板中的节点,所以是因为边缘没有自动保存在检查点中,以及它没有被保存为节点或任何其他形式的事实故意在原始代码,第一次坐完后不能调用它? 这是训练阶段的一般结构:
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
with tf.name_scope('1st_pool'):
#first layer
#subsequent layers
with graph.as_default():
#flattening, dropout, optimization, etc...
#some summary.scalar for loss analyses
logits_out = tf.layers.dense(flat, 1) #flat is the flattened array
saved_1 = tf.train.Saver()
trained_event = tf.summary.FileWriter('./CNN/train', graph=graph)
test_event = tf.summary.FileWriter('./CNN/test', graph=graph)
merged = tf.summary.merge_all()
with tf.Session(graph=graph) as sess:
#training and "validating"
sess.run(tf.global_variables_initializer())
#running train summaries
if step = test_round:
#running test summaries
saved_1.save(sess, './CNN/model_1.ckpt')
(已编辑:代码粘贴错误) 此代码在图形仍然打开的连续坐着期间成功运行:
with tf.Session(graph=graph) as sess:
saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
#
pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
#
几乎只调整了 2 行(如下所示) 以在第二天将元文件加载到新图表中,但当我尝试运行时出现错误“名称 'logits_out' 未定义”在一个单独的位置(事实上,我尝试 sess.run 的其他变量给出了相同的错误):
with tf.Session(graph=tf.get_default_graph()) as sess:
saved_1 = tf.train.import_meta_graph('./CNN/model_1.ckpt.meta')
saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
#
已编辑:我认为这可能是因为我错过了一个范围 - 或者误解了 tensorflow 如何命名东西 - 在第二天恢复会话/图表后,但我看不到如何 - 唯一的事情是命名为池。
【问题讨论】:
【参考方案1】:我今天可以通过运行这部分代码来创建图表,从而通过模型运行数据:
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
with tf.name_scope('1st_pool'):
#first layer
#subsequent layers
with graph.as_default():
#flattening, dropout, optimization, etc...
#some summary.scalar for loss analyses
logits_out = tf.layers.dense(flat, 1) #flat is the flattened array
saved_1 = tf.train.Saver()
trained_event = tf.summary.FileWriter('./CNN/train', graph=graph)
test_event = tf.summary.FileWriter('./CNN/test', graph=graph)
merged = tf.summary.merge_all()
with tf.Session(graph=graph) as sess:
#training and "validating"
sess.run(tf.global_variables_initializer())
#running train summaries
if step = test_round:
#running test summaries
saved_1.save(sess, './CNN/model_1.ckpt')
然后运行
没有经过编辑的两行代码:
with tf.Session(graph=graph) as sess:
saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
#
pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
#
所以所有关于 SO 的整个帖子的要点是我不必使用 tf.train.import_meta_graph,但我不明白 tf.train.import_meta_graph 的用途是什么?我认为它会导入图形并将其元数据保存在“.meta”文件中,这样我就可以避免从源代码重建图形? (注意:我会在弄清楚后删除这个后记问题)
【讨论】:
以上是关于关闭会话后运行 TensorFlow 模型测试数据的主要内容,如果未能解决你的问题,请参考以下文章