如何使用 c++ 在 tensorflow 中保存模型
Posted
技术标签:
【中文标题】如何使用 c++ 在 tensorflow 中保存模型【英文标题】:How to save a model in tensorflow by using c++ 【发布时间】:2017-11-19 09:55:45 【问题描述】:如何使用 c++ 在 Tensorflow 中保存模型?我在谷歌和百度上搜索过,但没有找到任何解决方案。然后我看了tensorflow的api文档,介绍的少了关于C++的介绍
【问题讨论】:
【参考方案1】:模型保存仅在 Python 中实现。目前无法使用 C++ API 保存模型。 C++ API 允许您加载和使用模型,而不是训练或保存它们。
【讨论】:
【参考方案2】:假设您对 tensorflow C++ API 有基本的了解,并且知道如何使用 C++ API 构建图形。您可以使用这两个功能:
tensorflow::WriteTextProto()
:您可以从 tensorflow::Scope::ToGraphDef()
获得 tensorflow::GraphDef
(代表您定义的所有操作,例如加、乘、均值 .... 等),将 tensorflow::GraphDef
保存到文本 protobuf 文件
tensorflow::checkpoint::TensorSliceWriter
将参数矩阵的当前状态保存到外部文件(检查点),有点复杂,但对我来说效果很好
首先,您必须通过调用tensorflow::Session::Run
来获取训练参数,这会将参数矩阵列表返回给output_tensor
(参见下面的示例):
std::vector<tensorflow::Tensor> output_tensor;
tensorflow::Session::Run(, "name_of_param_mtx_1", "name_of_param_mtx_2",, , &output_tensor);
上面的name_of_param_mtx_1
和name_of_param_mtx_2
应该是tensorflow::Variable
中的参数矩阵的名称,例如
auto name_of_param_mtx_1 = tensorflow::ops::Variable (root.WithOpName("name_of_param_mtx_1"), 7, 17, tensorflow::DT_FLOAT);
那么你需要为tensorflow::checkpoint::TensorSliceWriter
准备以下内容:
tensorflow::Tensor.tensor_data().data()
获取参数原始数据的基地址
每个tensorflow::Tensor
的形状,通过调用tensorflow::Tensor::dim_size(NUM_DIMENSION)
。例如 7x17 2D 参数矩阵,NUM_DIMENSION 可以是 0 和 1,其中 tensorflow::Tensor::dim_size(0) 为 7,tensorflow::Tensor::dim_size(1) 为 17。
此检查点的名称,该名称必须与一个文件中的其他检查点不同
通过调用tensorflow::TensorSlice::ParseOrDie("-:-")
创建tensorflow::TensorSlice
,似乎tensorflow::TensorSlice::ParseOrDie
的唯一参数将在内部分析,例如-:-
表示获取矩阵的所有项。如果用户只想要训练过的参数矩阵的一部分,例如只取所有行的第二列,那么字符串参数可能是 -:2
,我还没有想出 tensorflow::TensorSlice::ParseOrDie
的这种高级用法。
希望对您有所帮助。
【讨论】:
以上是关于如何使用 c++ 在 tensorflow 中保存模型的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Tensorflow C++ 中定义一个自定义的有状态操作来保存变量的值?
从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理
如何在 C++ 中使用 TensorFlow Estimator?