如何在 C++ 中保存和恢复 TensorFlow 图及其状态?

Posted

技术标签:

【中文标题】如何在 C++ 中保存和恢复 TensorFlow 图及其状态?【英文标题】:How to save and restore a TensorFlow graph and its state in C++? 【发布时间】:2016-05-29 10:27:07 【问题描述】:

我正在使用 C++ 中的 TensorFlow 训练我的模型。 Python 仅用于构建图形。那么有没有办法纯粹在 C++ 中保存和恢复图形及其状态?我知道 Python 类 tf.train.Saver,但据我了解,它在 C++ 中不存在。

【问题讨论】:

【参考方案1】:

tf.train.Saver 类目前仅存在于 Python 中,但是 (i) 它是从 TensorFlow ops 构建的,您可以在 C++ 中运行,并且 (ii) 它公开了 Saver.as_saver_def() 方法让您获得一个SaverDef protocol buffer,其中包含您必须运行以保存或恢复模型的操作名称。

在 Python 中,您可以获得保存和恢复操作的名称,如下所示:

saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()

# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name

# The name of the target operation you must run when restoring.
print saver_def.restore_op_name

# The name of the target operation you must run when saving.
print saver_def.save_tensor_name

在 C++ 中要从检查点恢复,您调用 Session::Run(),将检查点文件的名称输入为 saver_def.filename_tensor_name,目标操作为 saver_def.restore_op_name。要保存另一个检查点,请调用Session::Run(),再次输入检查点文件的名称为saver_def.filename_tensor_name,并获取saver_def.save_tensor_name 的值。

【讨论】:

好建议!我不得不从一个字符串的末尾删除一个“:0”。此外,相对路径在恢复模型期间不起作用。张量创建:tf::Tensor string( tf::DT_STRING, tf::TensorShape( 1, 1 ) ); 馈送字符串:string.matrix< std::string >()( 0, 0 ) = file_path_ + filename; 执行:TF_CHECK_OK( session_->Run( "save/Const:0", string , , "save/control_dependency" , nullptr ) ); @Trevir,mrry:你能发一下sn-p吗?我是 tensorflow 的新手,文档没有帮助。我非常感谢你! @Surferonthefall:前一条评论包含所有必要的代码。使用 python 脚本获取正确的操作名称,例如“保存/常量:0”。之后就可以通过 session->run 方法在 c++ 中使用操作名了。 惊人的hacky解决方案! Python 脚本必须包含 saver = tf.train.Saver(...) 行。我可以确认必须将“save/control_dependency:0”重命名为“save/control_dependency” @mrry 你能看看这个吗? github.com/tensorflow/tensorflow/issues/…【参考方案2】:

最新的 TensorFlow 版本包含一些帮助函数,可以在没有 Python 的 C++ 中执行相同的操作。这些是从 pip 包 ($HOME/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h) 中的 ProtoBuf 生成的。

// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = graph_def.saver_def().filename_tensor_name(), checkpointPathTensor;
status = sess->Run(feed_dict, , graph_def.saver_def().save_tensor_name(), nullptr);

// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = graph_def.saver_def().filename_tensor_name(), checkpointPathTensor;
status = sess->Run(feed_dict, , graph_def.saver_def().restore_op_name(), nullptr);

这是基于恢复模型的无证 python 方式 (more details)

def restore(sess, metaGraph, fn):
    restore_op_name = metaGraph.as_saver_def().restore_op_name   # u'save/restore_all'
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
    filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name  # u'save/Const'
    sess.run(restore_op, filename_tensor_name: fn)

对于一个工作和完整的version see here。

【讨论】:

以上是关于如何在 C++ 中保存和恢复 TensorFlow 图及其状态?的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 c++ 在 tensorflow 中保存模型

在 tensorflow 中恢复图形失败,因为没有要保存的变量

Tensorflow:如何使用恢复的模型?

风格迁移:在 tensorflow 1.15.0 中保存和恢复检查点/模型

Tensorflow:保存和恢复模型参数

AI - TensorFlow - 示例05:保存和恢复模型