组合图:C++ 是不是有等效的 TensorFlow import_graph_def?
Posted
技术标签:
【中文标题】组合图:C++ 是不是有等效的 TensorFlow import_graph_def?【英文标题】:Combining graphs: is there a TensorFlow import_graph_def equivalent for C++?组合图:C++ 是否有等效的 TensorFlow import_graph_def? 【发布时间】:2018-03-26 11:22:59 【问题描述】:我需要使用自定义输入和输出层扩展导出的模型。我发现这可以轻松完成:
with tf.Graph().as_default() as g1: # actual model
in1 = tf.placeholder(tf.float32,name="input")
ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
in2 = tf.placeholder(tf.float32,name="input")
ou2 = tf.add(in2,2.0,name="output")
gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()
with tf.Graph().as_default() as g_combined: #merge together
x = tf.placeholder(tf.float32, name="actual_input") # the new input layer
# Import gdef_1, which performs f(x).
# "input:0" and "output:0" are the names of tensors in gdef_1.
y, = tf.import_graph_def(gdef_1, input_map="input:0": x,
return_elements=["output:0"])
# Import gdef_2, which performs g(y)
z, = tf.import_graph_def(gdef_2, input_map="input:0": y,
return_elements=["output:0"])
sess = tf.Session(graph=g_combined)
print "result is: ", sess.run(z, "actual_input:0":5) #result is: 9
这很好用。
但是,我需要提供一个指针作为网络输入,而不是传递任意形状的数据集。问题是,我在 python 中想不出任何解决方案(定义和传递指针),并且在使用C++ Api
开发网络时,我找不到与tf.import_graph_def
函数等效的方法。
这在 C++ 中有不同的名称,还是有其他方法可以在 C++ 中合并两个图形/模型?
感谢您的建议
【问题讨论】:
【参考方案1】:这不像在 Python 中那么容易。
您可以使用以下内容加载GraphDef
:
#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>
tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) /* Error... */
然后你就可以用它来创建会话了:
#include <tensorflow/core/public/session.h>
tensorflow::Session *newSession;
auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &newSession);
if (!status.ok()) /* Error... */
status = session->Create(graph);
if (!status.ok()) /* Error... */
或者扩展现有的图形:
status = session->Extend(graph);
if (!status.ok()) /* Error... */
这样您可以将多个GraphDef
s 放入同一个图表中。但是,没有额外的工具来提取特定节点,也没有避免名称冲突 - 您必须自己找到节点,并且必须确保 GraphDef
s 没有冲突的操作名称。例如,我使用此函数查找名称与给定正则表达式匹配的所有节点,按名称排序:
#include <vector>
#include <regex>
#include <tensorflow/core/framework/node_def.pb.h>
std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex ®ex)
std::vector<const tensorflow::NodeDef *> nodes;
for (const auto &node : graph.node())
if (std::regex_match(node.name(), regex))
nodes.push_back(&node);
std::sort(nodes.begin(), nodes.end(),
[](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
return lhs->name() < rhs->name();
);
return nodes;
【讨论】:
【参考方案2】:这可以在C++中通过直接操作要合并的两个Graph的GraphDefs中的NodeDefs来实现。基本算法是定义两个 GraphDef,使用占位符作为第二个 GraphDef 的输入,并将它们重定向到第一个 GraphDef 的输出。这类似于通过将第二个电路的输入连接到第一个电路的输出来串联组合两个电路。
首先,定义示例 GraphDef,以及用于观察 GraphDef 内部的实用程序。请务必注意,两个 GraphDef 中的所有节点都必须具有唯一的名称。
Status Panel::SampleFirst(GraphDef *graph_def)
Scope root = Scope::NewRootScope();
Placeholder p1(root.WithOpName("p1"), DT_INT32);
Placeholder p2(root.WithOpName("p2"), DT_INT32);
Add add(root.WithOpName("add"), p1, p2);
return root.ToGraphDef(graph_def);
Status Panel::SampleSecond(GraphDef *graph_def)
Scope root = Scope::NewRootScope();
Placeholder q1(root.WithOpName("q1"), DT_INT32);
Placeholder q2(root.WithOpName("q2"), DT_INT32);
Add sum(root.WithOpName("sum"), q1, q2);
Multiply multiply(root.WithOpName("multiply"), sum, 4);
return root.ToGraphDef(graph_def);
void Panel::ShowGraphDef(GraphDef &graph_def)
for (int i = 0; i < graph_def.node_size(); i++)
NodeDef node_def = graph_def.node(i);
cout << "NodeDef name is " << node_def.name() << endl;
cout << "NodeDef op is " << node_def.op() << endl;
for (const string& input : node_def.input())
cout << "\t input: " << input << endl;
现在创建了两个 GraphDef,第二个 GraphDef 的输入连接到第一个 GraphDef 的输出。这是通过迭代节点并识别第一个操作节点(其输入是占位符)并将这些输入重定向到第一个 GraphDef 的输出来完成的。然后将该节点添加到第一个 GraphDef 以及所有后续节点。结果是第一个 GraphDef 由第二个 GraphDef 附加。
Status Panel::Append(vector<Tensor> *outputs)
GraphDef graph_def_first;
GraphDef graph_def_second;
TF_RETURN_IF_ERROR(SampleFirst(&graph_def_first));
TF_RETURN_IF_ERROR(SampleSecond(&graph_def_second));
for (int i = 0; i < graph_def_second.node_size(); i++)
NodeDef node_def = graph_def_second.node(i);
if (node_def.name() == "sum")
node_def.set_input(0, "p1");
node_def.set_input(1, "add");
*graph_def_first.add_node() = node_def;
ShowGraphDef(graph_def_first);
unique_ptr<Session> session(NewSession(SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph_def_first));
Tensor t1(2);
Tensor t2(3);
vector<pair<string, Tensor>> inputs = "p1", t1, "p2", t2;
TF_RETURN_IF_ERROR(session->Run(inputs, "multiply", , outputs));
return Status::OK();
这个特定的图表将接受两个输入,2 和 3,并将它们相加。然后将那个(5)的和再次加到第一个输入(2)上,然后乘以4得到28的结果。((2+3) + 2) * 4 = 28。
【讨论】:
以上是关于组合图:C++ 是不是有等效的 TensorFlow import_graph_def?的主要内容,如果未能解决你的问题,请参考以下文章
C++ 中是不是有与“for ... else”Python 循环等效的方法?
Java 中协议缓冲区分隔的 I/O 函数是不是有 C++ 等效项?
使用 QWT 和 Microsoft Visual C++ 2010 绘制 MatLab 等效图