恢复使用迭代器的 TensorFlow 模型

Posted

技术标签:

【中文标题】恢复使用迭代器的 TensorFlow 模型【英文标题】:Restoring a Tensorflow model that uses Iterators 【发布时间】:2018-04-05 16:01:47 【问题描述】:

我有一个使用迭代器训练我的网络的模型;遵循 Google 现在推荐的新数据集 API 管道模型。

我读取了 tfrecord 文件,向网络提供数据,训练得很好,一切顺利,我在训练结束时保存了我的模型,以便稍后在其上运行推理。代码的简化版本如下:

""" Training and saving """

training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
  training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
  next_training_element = training_iterator.get_next()
  training_init_op = training_iterator.make_initializer(training_dataset)

def train(num_epochs):
  # compute for the number of epochs
  for e in range(1, num_epochs+1):
    session.run(training_init_op) #initializing iterator here
    while True:
      try:
        images, labels = session.run(next_training_element)
        session.run(optimizer, feed_dict=x: images, y_true: labels)
      except tf.errors.OutOfRangeError:
        saver_name = './saved_models/ucf-model'
        print("Finished Training Epoch ".format(e))
        break



    """ Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')
saver.restore(session, tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()

# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser, num_threads=4, output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)

testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset

with tf.Session() as session:
  session.run(testing_init_op)
  while True:
    try:
      images, labels = session.run(next_testing_element)
      accuracy = session.run(accuracy, feed_dict=x: test_images, y_true: test_labels) #error here, x, y_true not defined
    except tf.errors.OutOfRangeError:
      break

我的问题主要是当我恢复模型时。如何将测试数据馈送到网络?

当我使用testing_iterator = graph.get_operation_by_name("iterators/Iterator")next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") 恢复我的迭代器时,我收到以下错误: GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element. 所以我确实尝试使用:testing_init_op = testing_iterator.make_initializer(testing_dataset)) 初始化我的数据集。我收到了这个错误:AttributeError: 'Operation' object has no attribute 'make_initializer'

另一个问题是,由于正在使用迭代器,因此无需在 training_model 中使用占位符,因为迭代器将数据直接提供给图形。但是这样,当我将数据提供给“准确性”操作时,如何恢复我在第 3 行到最后一行的 feed_dict 键?

编辑:如果有人可以建议一种在迭代器和网络输入之间添加占位符的方法,那么我可以尝试通过评估“准确性”张量来运行图形,同时将数据提供给占位符并完全忽略迭代器。

【问题讨论】:

【参考方案1】:

我建议看看 CheckpointInputPipelineHook CheckpointInputPipelineHook,它实现了保存迭代器状态以使用 tf.Estimator 进行进一步训练。

【讨论】:

【参考方案2】:

我建议使用专门为此目的设计的tf.contrib.data.make_saveable_from_iterator。它不那么冗长,并且不需要您更改现有代码,尤其是您定义迭代器的方式。

工作示例,当我们在第 5 步完成后保存所有内容时。请注意,我什至不知道使用的是什么种子。

import tensorflow as tf

iterator = (
  tf.data.Dataset.range(100)
  .shuffle(10)
  .make_one_shot_iterator())
batch = iterator.get_next(name='batch')

saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
saver = tf.train.Saver()

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  for step in range(10):
    print(': '.format(step, sess.run(batch)))
    if step == 5:
      saver.save(sess, './foo', global_step=step)

# 0: 1
# 1: 6
# 2: 7
# 3: 3
# 4: 8
# 5: 10
# 6: 12
# 7: 14
# 8: 5
# 9: 17

然后,如果我们从第 6 步继续,我们会得到相同的输出。

import tensorflow as tf

saver = tf.train.import_meta_graph('./foo-5.meta')
with tf.Session() as sess:
  saver.restore(sess, './foo-5')
  for step in range(6, 10):
    print(': '.format(step, sess.run('batch:0')))
# 6: 12
# 7: 14
# 8: 5
# 9: 17

【讨论】:

【参考方案3】:

恢复已保存的元图时,可以恢复带有名称的初始化操作,然后再次使用它来初始化输入管道以进行推理。

即在创建图表时,可以这样做

    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

然后通过执行以下操作恢复此操作:

    dataset_init_op = graph.get_operation_by_name('dataset_init')

这是一个独立的代码sn-p,它比较了一个随机初始化模型在恢复之前和之后的结果。

保存迭代器

np.random.seed(42)
data = np.random.random([4, 4])
X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the operation
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')     
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict=X:data)
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        saver.save(sess, 'tmp/', global_step=1002)
    break

然后您可以恢复相同的模型进行推理,如下所示:

恢复保存的迭代器

np.random.seed(42)
data = np.random.random([4, 4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')

X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op, feed_dict=X:data)
while True:
try:
    print(sess.run(output))
except tf.errors.OutOfRangeError:
    break

【讨论】:

这很有帮助。是否可以使用新数据集重新初始化迭代器?看来我真正想做的是保存可重新初始化的迭代器本身(不是dataset_init_op)。然后我想恢复那个迭代器并用一个新的数据集创建一个新的初始化器。但是当我尝试保存迭代器操作时,TF 会抱怨。 @masonk 我不确定你想说什么。如果您正在谈论使用具有相同结构的不同数据(例如训练集与测试集)初始化相同的dataset 管道,那么只需使用新数据作为参数运行dataset_init_op feed_dict 如上图所示。这是用新数据初始化可重新初始化的迭代器。如果输入数据的结构不同,则需要查看feedable iterator。 我今天在 tensorflow/tensorflow 上提交了这个错误,它描述了我的用例:github.com/tensorflow/tensorflow/issues/20098 最后我确实使用了一个可馈送的迭代器。我认为只有当你的 dset 在内存中时,你正在做的喂占位符变量才是好的。我想从新数据集中重新初始化一个可重新初始化的迭代器。【参考方案4】:

我无法解决与初始化迭代器相关的问题,但由于我使用map 方法预处理了我的数据集,并且我应用了由py_func 包装的 Python 操作定义的转换,这些转换无法序列化以进行存储\restoreing,无论如何我都必须初始化我的数据集。

所以,剩下的问题是当我恢复它时如何将数据提供给我的图表。我在迭代器输出和网络输入之间放置了一个 tf.identity 节点。恢复后,我将数据提供给身份节点。我后来发现的一个更好的解决方案是使用placeholder_with_default(),如this answer 中所述。

【讨论】:

以上是关于恢复使用迭代器的 TensorFlow 模型的主要内容,如果未能解决你的问题,请参考以下文章

无法使用经过训练的 Tensorflow 模型

Tensorflow:如何使用恢复的模型?

Tensorflow:保存和恢复模型参数

TensorFlow Saver的使用方法

TensorFlow 模型恢复(恢复训练似乎从头开始)

使用TensorFlow进行股票价格预测的简单深度学习模型