关闭会话后运行 TensorFlow 模型测试数据

Posted

技术标签:

【中文标题】关闭会话后运行 TensorFlow 模型测试数据【英文标题】:Running Tensorflow model test data after closing session 【发布时间】:2019-01-06 06:23:23 【问题描述】:

我有一个我正在尝试复制的 Convnet(不是我的原始代码),只有当我在同一个座位上进行训练和测试时,它才能将测试数据集运行到经过训练的模型中。我只调整了几行代码,让它在说坐之后运行测试数据,所以我不确定会发生什么。我注意到“logits_out”是数据流边缘而不是张量板中的节点,所以是因为边缘没有自动保存在检查点中,以及它没有被保存为节点或任何其他形式的事实故意在原始代码,第一次坐完后不能调用它? 这是训练阶段的一般结构:

tf.reset_default_graph()
graph = tf.Graph()

with graph.as_default():
    with tf.name_scope('1st_pool'):
        #first layer
#subsequent layers

with graph.as_default():
    #flattening, dropout, optimization, etc...
    #some summary.scalar for loss analyses
    logits_out = tf.layers.dense(flat, 1) #flat is the flattened array

    saved_1 = tf.train.Saver()
    trained_event = tf.summary.FileWriter('./CNN/train', graph=graph)

    test_event = tf.summary.FileWriter('./CNN/test', graph=graph)

    merged = tf.summary.merge_all()

with tf.Session(graph=graph) as sess:
    #training and "validating"
    sess.run(tf.global_variables_initializer())
    #running train summaries

    if step = test_round:
        #running test summaries
        saved_1.save(sess, './CNN/model_1.ckpt')

(已编辑:代码粘贴错误) 此代码在图形仍然打开的连续坐着期间成功运行:

with tf.Session(graph=graph) as sess:

    saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
    #
    pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
    #

几乎只调整了 2 行(如下所示) 以在第二天将元文件加载到新图表中,但当我尝试运行时出现错误“名称 'logits_out' 未定义”在一个单独的位置(事实上,我尝试 sess.run 的其他变量给出了相同的错误):

with tf.Session(graph=tf.get_default_graph()) as sess:
    saved_1 = tf.train.import_meta_graph('./CNN/model_1.ckpt.meta')
    saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
    pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
    #

已编辑:我认为这可能是因为我错过了一个范围 - 或者误解了 tensorflow 如何命名东西 - 在第二天恢复会话/图表后,但我看不到如何 - 唯一的事情是命名为池。

【问题讨论】:

【参考方案1】:

我今天可以通过运行这部分代码来创建图表,从而通过模型运行数据:

tf.reset_default_graph()
graph = tf.Graph()

with graph.as_default():
    with tf.name_scope('1st_pool'):
        #first layer
#subsequent layers

with graph.as_default():
    #flattening, dropout, optimization, etc...
    #some summary.scalar for loss analyses
    logits_out = tf.layers.dense(flat, 1) #flat is the flattened array

    saved_1 = tf.train.Saver()
    trained_event = tf.summary.FileWriter('./CNN/train', graph=graph)

    test_event = tf.summary.FileWriter('./CNN/test', graph=graph)

    merged = tf.summary.merge_all()

with tf.Session(graph=graph) as sess:
    #training and "validating"
    sess.run(tf.global_variables_initializer())
    #running train summaries

    if step = test_round:
        #running test summaries
        saved_1.save(sess, './CNN/model_1.ckpt')

然后运行

没有经过编辑的两行代码:

with tf.Session(graph=graph) as sess:

    saved_1.restore(sess, tf.train.latest_checkpoint('./CNN'))
    #
    pred = sess.run(logits_out, feed_dict=some inputs for placeholders)
    #

所以所有关于 SO 的整个帖子的要点是我不必使用 tf.train.import_meta_graph,但我不明白 tf.train.import_meta_graph 的用途是什么?我认为它会导入图形并将其元数据保存在“.meta”文件中,这样我就可以避免从源代码重建图形? (注意:我会在弄清楚后删除这个后记问题)

【讨论】:

以上是关于关闭会话后运行 TensorFlow 模型测试数据的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 运行模型--会话(Session)

Tensorflow运行模型——会话

《Tensorflow技术解析与实战》第四章

如何使用 django 保持 tensorflow 会话在内存中运行

由于张量数据类型和形状Tensorflow,运行会话失败

推荐阅读 | 如何让TensorFlow模型运行提速36.8%(续)