从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理
Posted
技术标签:
【中文标题】从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理【英文标题】:loading a graph from .meta file from Tensorflow in c++ for inference 【发布时间】:2018-07-28 05:22:44 【问题描述】:我已经使用 tensorflow 1.5.1 训练了一些模型,并且我有这些模型的检查点(包括 .ckpt 和 .meta 文件)。现在我想使用这些文件在 C++ 中进行推理。
在 python 中,我会执行以下操作来保存和加载图形和检查点。 保存:
images = tf.placeholder(...) // the input layer
//the graph def
output = tf.nn.softmax(net) // the output layer
tf.add_to_collection('images', images)
tf.add_to_collection('output', output)
为了推断,我恢复图形和检查点,然后从集合中恢复输入和输出层,如下所示:
meta_file = './models/last-100.meta'
ckpt_file = './models/last-100'
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, ckpt_file)
images = tf.get_collection('images')
output = tf.get_collection('output')
outputTensors = sess.run(output, feed_dict=images: np.array(an_image))
现在假设我像往常一样在 python 中进行了保存,我如何使用 python 中的简单代码在 c++ 中进行推理和恢复?
我找到了示例和教程,但对于 tensorflow 0.7 0.12 版本,相同的代码不适用于 1.5 版本。我在 tensorflow 网站上没有找到使用 c++ API 恢复模型的教程。
【问题讨论】:
见github.com/PatWie/tensorflow_inference 【参考方案1】:为了这个thread。我会将我的评论改写为答案。
发布完整示例需要 CMake 设置或将文件放入特定目录以运行 bazel。因为我喜欢第一种方式,它会打破这篇文章的所有限制以涵盖所有部分,我想重定向到我为 TF > v1.5 测试过的complete implementation in C99, C++, GO without Bazel。
在 C++ 中加载图表并不比在 Python 中困难多少,鉴于您已经从源代码编译了 TensorFlow。
从创建一个 MWE 开始,它会创建一个非常转储的网络图,这对于弄清楚事情是如何工作的总是一个好主意:
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './exported/my_model')
关于这部分的 SO 这里可能有很多答案。所以我就让它留在这里,不再解释。
在 Python 中加载
在用其他语言做东西之前,我们可以尝试在 python 中正确地做——从某种意义上说:我们只需要用 C++ 重写它。 甚至在 python 中恢复也很容易,例如:
import tensorflow as tf
with tf.Session() as sess:
# load the computation graph
loader = tf.train.import_meta_graph('./exported/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, './exported/my_model')
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
它没有帮助,因为大多数这些 API 端点在 C++ API 中不存在(还没有?)。另一个版本是
import tensorflow as tf
with tf.Session() as sess:
metaGraph = tf.train.import_meta_graph('./exported/my_model.meta')
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
sess.run(restore_op, filename_tensor_name: './exported/my_model')
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
等一下。您始终可以使用print(dir(object))
来获取restore_op_name
、...等属性。
与其他操作一样,恢复模型是 TensorFlow 中的一项操作。我们只是调用此操作并提供路径(字符串张量)作为输入。我们甚至可以编写自己的restore
操作
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, filename_tensor_name: fn)
即使这看起来很奇怪,但现在在 C++ 中做同样的事情有很大帮助。
在 C++ 中加载
从平常的东西开始
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/public/session_options.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <string>
#include <iostream>
typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict;
int main(int argc, char const *argv[])
const std::string graph_fn = "./exported/my_model.meta";
const std::string checkpoint_fn = "./exported/my_model";
// prepare session
tensorflow::Session *sess;
tensorflow::SessionOptions options;
TF_CHECK_OK(tensorflow::NewSession(options, &sess));
// here we will put our loading of the graph and weights
return 0;
您应该能够通过将其放入 TensorFlow 存储库并使用 bazel 来编译它,或者只需按照说明 here 使用 CMake。
我们需要创建这样一个由tf.train.import_meta_graph
创建的meta_graph
。这可以通过
tensorflow::MetaGraphDef graph_def;
TF_CHECK_OK(ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def));
在 C++ 中,从文件中读取图形不与在 Python 中导入图形相同。我们需要在会话中通过
TF_CHECK_OK(sess->Create(graph_def.graph_def()));
通过查看上面奇怪的pythonrestore
函数:
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
我们可以用 C++ 编写等效的代码
const std::string restore_op_name = graph_def.saver_def().restore_op_name()
const std::string filename_tensor_name = graph_def.saver_def().filename_tensor_name()
有了这个,我们就可以运行操作了
sess->Run(feed_dict, // inputs
, // output_tensor_names (we do not need them)
restore_op, // target_node_names
nullptr) // outputs (there are no outputs this time)
创建 feed_dict 本身可能是一个帖子,这个答案已经足够长了。它只涵盖最重要的东西。我想重定向到我针对 TF > v1.5 测试过的complete implementation in C99, C++, GO without Bazel。这并不难——在plain C version 的情况下它可能会变得很长。
【讨论】:
所以要恢复/保存模型,我必须在 C++ 中调用session->Run()
?
就像你在 python 中所做的一样。以上是关于从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理的主要内容,如果未能解决你的问题,请参考以下文章
用于从 C++ 自动生成的 python 模块的 TensorFlow 源
在 C++ 中为 Tensorflow 模型定义 feed_dict
在 Windows 上构建 C++ 项目时,Tensorflow 2.3 无法解析机器生成的文件中的外部符号