如何使用 tf.MonitoredTrainingSession 在训练和验证数据集之间切换?

Posted

技术标签:

【中文标题】如何使用 tf.MonitoredTrainingSession 在训练和验证数据集之间切换?【英文标题】:How to switch between training and validation dataset with tf.MonitoredTrainingSession? 【发布时间】:2018-08-12 05:53:58 【问题描述】:

我想在 tensorflow Dataset API 中使用feedable 迭代器设计,这样我可以在一些训练步骤后切换到验证数据。但如果我切换到验证数据,它将结束整个会话。

以下代码演示了我想要做什么:

import tensorflow as tf


graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()


with graph.as_default():

    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict=handle: training_handle)
            count_training += 1
            print(' [training] '.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict=handle: validation_handle)
                    count_validation += 1
                    print('   [validation] '.format(count_validation, y.shape))
                    # print(y)

训练数据有32个元素,用4个batch,所以得到8个batch 我们每 4 步进行一次验证,所以我希望:

#  1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
#      1 [validation]
#      2 [validation]

但在第一次验证完成时它会停止:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]

那么,如何在tf.MonitoredTrainingSession 中使用这个feedable 迭代器?

【问题讨论】:

【参考方案1】:

我建议在验证数据集末尾捕获 tf.errors.OutOfRangeError (您也可以在官方 API 中查看 the processing multiple epochs section 以获得使用 repeat 数据集的另一个解决方案):

while not sess.should_stop():
    x = sess.run(next_element, feed_dict=handle: training_handle)
    count_training += 1
    print(' [training] '.format(count_training, x.shape))

    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict=handle: validation_handle)
                count_validation += 1
                print('   [validation] '.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                break

这段代码打印:

1 [training] (4,)  
2 [training] (4,)  
3 [training] (4,)  
4 [training] (4,)  
  1 [validation] (4,)  
  2 [validation] (4,)  
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)

【讨论】:

这真的很有趣,这意味着,在内部,sess.should_stop() 只需检查数据集管道的状态以跳过抛出该异常,如果您自己捕获该异常,则状态会变回有效再次。 发生的事情是验证dataset 引发tf.errors.OutOfRangeErrorMonitoredTrainingSession 捕获它并自行停止。

以上是关于如何使用 tf.MonitoredTrainingSession 在训练和验证数据集之间切换?的主要内容,如果未能解决你的问题,请参考以下文章

如何使用本机反应创建登录以及如何验证会话

如何在自动布局中使用约束标识符以及如何使用标识符更改约束? [迅速]

如何使用 AngularJS 的 ng-model 创建一个数组以及如何使用 jquery 提交?

如何使用laravel保存所有行数据每个行名或相等

如何使用 Math.Net 连接矩阵。如何使用 Math.Net 调用特定的行或列?

WSARecv 如何使用 lpOverlapped?如何手动发出事件信号?