make_initializable_iterator 和 make_one_shot_iterator 之间的 tensorflow 数据集 API 差异

Posted

技术标签:

【中文标题】make_initializable_iterator 和 make_one_shot_iterator 之间的 tensorflow 数据集 API 差异【英文标题】:tensorflow Dataset API diff between make_initializable_iterator and make_one_shot_iterator 【发布时间】:2018-06-13 23:20:42 【问题描述】:

我想知道make_initializable_iteratormake_one_shot_iterator之间的区别。 1.Tensorflow 文档说A "one-shot" iterator does not currently support re-initialization. 到底是什么意思? 2.以下2个sn-ps是否等价? 使用make_initializable_iterator

iterator = data_ds.make_initializable_iterator()
data_iter = iterator.get_next()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for e in range(1, epoch+1):
    sess.run(iterator.initializer)
    while True:
        try:
            x_train, y_train = sess.run([data_iter])
            _, cost = sess.run([train_op, loss_op], feed_dict=X: x_train,
                                                               Y: y_train)
        except tf.errors.OutOfRangeError:   
            break
sess.close()

使用make_one_shot_iterator

iterator = data_ds.make_one_shot_iterator()
data_iter = iterator.get_next()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for e in range(1, epoch+1):
    while True:
        try:
            x_train, y_train = sess.run([data_iter])
            _, cost = sess.run([train_op, loss_op], feed_dict=X: x_train,
                                                               Y: y_train)
        except tf.errors.OutOfRangeError:   
            break
sess.close()

【问题讨论】:

【参考方案1】:

假设您想使用相同的代码进行训练和验证。您可能希望使用相同的迭代器,但初始化为指向不同的数据集;类似于以下内容:

def _make_batch_iterator(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    ...
    return dataset.make_initializable_iterator()


filenames = tf.placeholder(tf.string, shape=[None])
iterator = _make_batch_iterator(filenames)

with tf.Session() as sess:
    for epoch in range(num_epochs):

        # Initialize iterator with training data
        sess.run(iterator.initializer,
                 feed_dict=filenames: ['training.tfrecord'])

        _train_model(...)

        # Re-initialize iterator with validation data
        sess.run(iterator.initializer,
                 feed_dict=filenames: ['validation.tfrecord'])

        _validate_model(...)

使用一次性迭代器,您不能像这样重新初始化它。

【讨论】:

你能解释一下可初始化和可重新初始化的迭代器有什么区别吗? 我们可以“重新加载”数据集吗?例如:特征、标签 = trainDataset.make_one_shot_iterator().get_next()、图 = fn(特征、标签)。然后训练完,重新加载特征,labels = TestDataset.xxx().get_next?因为我认为它是一个不同的数据集,而不是重新初始化 @Leighton 这可能意味着您创建了额外的数据集图,这不是您通常想要的

以上是关于make_initializable_iterator 和 make_one_shot_iterator 之间的 tensorflow 数据集 API 差异的主要内容,如果未能解决你的问题,请参考以下文章