如何使用 c++ 在 tensorflow 中保存模型

Posted

技术标签:

【中文标题】如何使用 c++ 在 tensorflow 中保存模型【英文标题】:How to save a model in tensorflow by using c++ 【发布时间】:2017-11-19 09:55:45 【问题描述】:

如何使用 c++ 在 Tensorflow 中保存模型?我在谷歌和百度上搜索过,但没有找到任何解决方案。然后我看了tensorflow的api文档,介绍的少了关于C++的介绍

【问题讨论】:

【参考方案1】:

模型保存仅在 Python 中实现。目前无法使用 C++ API 保存模型。 C++ API 允许您加载和使用模型,而不是训练或保存它们。

【讨论】:

【参考方案2】:

假设您对 tensorflow C++ API 有基本的了解,并且知道如何使用 C++ API 构建图形。您可以使用这两个功能:

    tensorflow::WriteTextProto() :您可以从 tensorflow::Scope::ToGraphDef() 获得 tensorflow::GraphDef(代表您定义的所有操作,例如加、乘、均值 .... 等),将 tensorflow::GraphDef 保存到文本 protobuf 文件

    tensorflow::checkpoint::TensorSliceWriter 将参数矩阵的当前状态保存到外部文件(检查点),有点复杂,但对我来说效果很好

首先,您必须通过调用tensorflow::Session::Run 来获取训练参数,这会将参数矩阵列表返回给output_tensor(参见下面的示例):

std::vector<tensorflow::Tensor> output_tensor; 
tensorflow::Session::Run(, "name_of_param_mtx_1", "name_of_param_mtx_2",, , &output_tensor);

上面的name_of_param_mtx_1name_of_param_mtx_2 应该是tensorflow::Variable 中的参数矩阵的名称,例如

auto name_of_param_mtx_1 = tensorflow::ops::Variable (root.WithOpName("name_of_param_mtx_1"), 7, 17, tensorflow::DT_FLOAT);

那么你需要为tensorflow::checkpoint::TensorSliceWriter准备以下内容:

通过调用tensorflow::Tensor.tensor_data().data()获取参数原始数据的基地址 每个tensorflow::Tensor 的形状,通过调用tensorflow::Tensor::dim_size(NUM_DIMENSION)。例如 7x17 2D 参数矩阵,NUM_DIMENSION 可以是 0 和 1,其中 tensorflow::Tensor::dim_size(0) 为 7,tensorflow::Tensor::dim_size(1) 为 17。 此检查点的名称,该名称必须与一个文件中的其他检查点不同 通过调用tensorflow::TensorSlice::ParseOrDie("-:-")创建tensorflow::TensorSlice,似乎tensorflow::TensorSlice::ParseOrDie的唯一参数将在内部分析,例如-:- 表示获取矩阵的所有项。如果用户只想要训练过的参数矩阵的一部分,例如只取所有行的第二列,那么字符串参数可能是 -:2 ,我还没有想出 tensorflow::TensorSlice::ParseOrDie 的这种高级用法。

希望对您有所帮助。

【讨论】:

以上是关于如何使用 c++ 在 tensorflow 中保存模型的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow C++ 中定义一个自定义的有状态操作来保存变量的值?

从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理

如何在 C++ 中使用 TensorFlow Estimator?

如何使用 tensorflow 在 C++ 中训练模型?

TensorFlow C++ 评估性能比 Python 一更差

如何构建和使用 Google TensorFlow C++ api