使用 tf.train.MonitoredTrainingSession 时如何获取全局步长
Posted
技术标签:
【中文标题】使用 tf.train.MonitoredTrainingSession 时如何获取全局步长【英文标题】:How to obtain global step when using tf.train.MonitoredTrainingSession 【发布时间】:2018-01-19 10:04:17 【问题描述】:当我们在Saver.save
中指定global_step时,它会将global_step存储为checkpoint后缀。
# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)
我们可以像这样恢复检查点并获取存储在检查点中的最后一个全局步骤:
# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
如果我们使用tf.train.MonitoredTrainingSession
,那么将全局步骤保存到检查点并获取gstep
的等效方法是什么?
编辑 1
按照Maxim的建议,我在tf.train.MonitoredTrainingSession
之前创建了global_step
变量,并像这样添加了CheckpointSaverHook
:
global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
save_steps=5,
checkpoint_basename=(checkpoints_prefix + ".ckpt"))
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=is_chief,
hooks=[sync_replicas_hook, save_checkpoint_hook],
config=config) as session:
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
print("current global step=" + str(gstep))
我可以看到它生成的检查点文件类似于Saver.saver
所做的。但是,它无法从检查点检索全局步骤。请告知我应该如何解决这个问题?
【问题讨论】:
【参考方案1】:您可以通过tf.train.get_global_step()
或tf.train.get_or_create_global_step()
函数获取当前全局步骤。后者应在训练开始前调用。
对于被监控的会话,将tf.train.CheckpointSaverHook
添加到hooks
,它内部使用定义的全局步长张量在每N步后保存模型。
【讨论】:
我已对原始帖子进行了编辑,说明我无法检索 global_step。请你看看好吗?以上是关于使用 tf.train.MonitoredTrainingSession 时如何获取全局步长的主要内容,如果未能解决你的问题,请参考以下文章
在使用加载数据流步骤的猪中,使用(使用 PigStorage)和不使用它有啥区别?
Qt静态编译时使用OpenSSL有三种方式(不使用,动态使用,静态使用,默认是动态使用)