如何在 C++ 中保存和恢复 TensorFlow 图及其状态?
Posted
技术标签:
【中文标题】如何在 C++ 中保存和恢复 TensorFlow 图及其状态?【英文标题】:How to save and restore a TensorFlow graph and its state in C++? 【发布时间】:2016-05-29 10:27:07 【问题描述】:我正在使用 C++ 中的 TensorFlow 训练我的模型。 Python 仅用于构建图形。那么有没有办法纯粹在 C++ 中保存和恢复图形及其状态?我知道 Python 类 tf.train.Saver
,但据我了解,它在 C++ 中不存在。
【问题讨论】:
【参考方案1】:tf.train.Saver
类目前仅存在于 Python 中,但是 (i) 它是从 TensorFlow ops 构建的,您可以在 C++ 中运行,并且 (ii) 它公开了 Saver.as_saver_def()
方法让您获得一个SaverDef
protocol buffer,其中包含您必须运行以保存或恢复模型的操作名称。
在 Python 中,您可以获得保存和恢复操作的名称,如下所示:
saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()
# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name
# The name of the target operation you must run when restoring.
print saver_def.restore_op_name
# The name of the target operation you must run when saving.
print saver_def.save_tensor_name
在 C++ 中要从检查点恢复,您调用 Session::Run()
,将检查点文件的名称输入为 saver_def.filename_tensor_name
,目标操作为 saver_def.restore_op_name
。要保存另一个检查点,请调用Session::Run()
,再次输入检查点文件的名称为saver_def.filename_tensor_name
,并获取saver_def.save_tensor_name
的值。
【讨论】:
好建议!我不得不从一个字符串的末尾删除一个“:0”。此外,相对路径在恢复模型期间不起作用。张量创建:tf::Tensor string( tf::DT_STRING, tf::TensorShape( 1, 1 ) );
馈送字符串:string.matrix< std::string >()( 0, 0 ) = file_path_ + filename;
执行:TF_CHECK_OK( session_->Run( "save/Const:0", string , , "save/control_dependency" , nullptr ) );
@Trevir,mrry:你能发一下sn-p吗?我是 tensorflow 的新手,文档没有帮助。我非常感谢你!
@Surferonthefall:前一条评论包含所有必要的代码。使用 python 脚本获取正确的操作名称,例如“保存/常量:0”。之后就可以通过 session->run 方法在 c++ 中使用操作名了。
惊人的hacky解决方案! Python 脚本必须包含 saver = tf.train.Saver(...)
行。我可以确认必须将“save/control_dependency:0”重命名为“save/control_dependency”
@mrry 你能看看这个吗? github.com/tensorflow/tensorflow/issues/…【参考方案2】:
最新的 TensorFlow 版本包含一些帮助函数,可以在没有 Python 的 C++ 中执行相同的操作。这些是从 pip 包 ($HOME/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h
) 中的 ProtoBuf 生成的。
// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = graph_def.saver_def().filename_tensor_name(), checkpointPathTensor;
status = sess->Run(feed_dict, , graph_def.saver_def().save_tensor_name(), nullptr);
// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = graph_def.saver_def().filename_tensor_name(), checkpointPathTensor;
status = sess->Run(feed_dict, , graph_def.saver_def().restore_op_name(), nullptr);
这是基于恢复模型的无证 python 方式 (more details)
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, filename_tensor_name: fn)
对于一个工作和完整的version see here。
【讨论】:
以上是关于如何在 C++ 中保存和恢复 TensorFlow 图及其状态?的主要内容,如果未能解决你的问题,请参考以下文章
在 tensorflow 中恢复图形失败,因为没有要保存的变量