TensorFlow 从文件中保存/加载图形
Posted
技术标签:
【中文标题】TensorFlow 从文件中保存/加载图形【英文标题】:TensorFlow saving into/loading a graph from a file 【发布时间】:2016-12-21 05:17:44 【问题描述】:根据我目前收集到的信息,有几种不同的方法可以将 TensorFlow 图转储到文件中,然后将其加载到另一个程序中,但我无法找到关于它们如何工作的明确示例/信息.我已经知道的是:
-
使用
tf.train.Saver()
将模型的变量保存到检查点文件 (.ckpt) 中并稍后恢复它们 (source)
将模型保存到 .pb 文件中,然后使用 tf.train.write_graph()
和 tf.import_graph_def()
(source) 重新加载它
从 .pb 文件加载模型,重新训练,然后使用 Bazel (source) 将其转储到新的 .pb 文件中
冻结图表以将图表和权重一起保存 (source)
使用as_graph_def()
保存模型,对于权重/变量,将它们映射为常量 (source)
但是,关于这些不同的方法,我无法解决几个问题:
-
关于检查点文件,它们是否只保存模型的训练权重?是否可以将检查点文件加载到新程序中并用于运行模型,或者它们是否只是作为在特定时间/阶段保存模型中权重的一种方式?
关于
tf.train.write_graph()
,是否也保存了权重/变量?
关于 Bazel,它是否只能从 .pb 文件中保存/加载以进行再训练?是否有一个简单的 Bazel 命令可以将图形转储到 .pb 中?
关于冻结,可以使用tf.import_graph_def()
加载冻结的图吗?
用于 TensorFlow 的 android 演示从 .pb 文件加载到 Google 的 Inception 模型中。如果我想替换我自己的 .pb 文件,我该怎么做呢?我需要更改任何本机代码/方法吗?
总的来说,所有这些方法之间究竟有什么区别?或者更广泛地说,as_graph_def()
/.ckpt/.pb 之间有什么区别?
简而言之,我正在寻找一种将图形(如各种操作等)及其权重/变量保存到文件中的方法,然后可用于加载图形和权重进入另一个程序,以供使用(不一定要继续/再培训)。
关于这个主题的文档不是很简单,所以任何答案/信息将不胜感激。
【问题讨论】:
最新/最完整的 API 是元图,它可以让您一次保存所有三个 - 1) 图 2) 参数值 3) 集合:tensorflow.org/versions/r0.10/how_tos/meta_graph/index.html 【参考方案1】:有很多方法可以解决在 TensorFlow 中保存模型的问题,这可能会让人有点困惑。依次回答每个子问题:
检查点文件(例如,通过在 tf.train.Saver
对象上调用 saver.save()
生成)仅包含权重以及在同一程序中定义的任何其他变量。要在另一个程序中使用它们,您必须重新创建关联的图结构(例如,通过运行代码重新构建它,或调用 tf.import_graph_def()
),它告诉 TensorFlow 如何处理这些权重。请注意,调用saver.save()
还会生成一个包含MetaGraphDef
的文件,其中包含一个图表以及如何将检查点的权重与该图表相关联的详细信息。详情请见the tutorial。
tf.train.write_graph()
只写图结构;不是重量。
Bazel 与读取或写入 TensorFlow 图无关。 (也许我误解了你的问题:请随时在评论中澄清。)
可以使用tf.import_graph_def()
加载冻结图。在这种情况下,权重(通常)嵌入图表中,因此您无需加载单独的检查点。
主要的变化是更新输入模型的张量的名称,以及从模型中获取的张量的名称。在 TensorFlow Android 演示中,这将对应于传递给 TensorFlowClassifier.initializeTensorFlow()
的 inputName
和 outputName
字符串。
GraphDef
是程序结构,通常不会在训练过程中改变。检查点是训练过程状态的快照,通常在训练过程的每一步都会发生变化。因此,TensorFlow 对这些类型的数据使用不同的存储格式,而底层 API 提供了不同的方式来保存和加载它们。更高级别的库,例如 MetaGraphDef
库、Keras 和 skflow 建立在这些机制之上,以提供更方便的方式来保存和恢复整个模型。
【讨论】:
这是否意味着C++ API documentation在说谎,当它说你可以加载用tf.train.write_graph()
保存的图形然后执行它?
C++ API 文档没有说谎,但它缺少一些细节。最重要的细节是,除了tf.train.write_graph()
保存的GraphDef
之外,您还需要记住在执行图形时要馈送和获取的张量的名称(上面的第5项)。
@mrry:我尝试使用 tensorflows DeepDream 示例。但它似乎需要 pb 格式的预训练模型!我运行了 Cifar10 示例,但它只创建检查点!我找不到任何 pb 文件或任何东西!如何将我的检查点转换为 deepdream 示例使用的 pb 格式?
@Coderx7 我真的认为您不能将 .ckpt 转换为 .pb,因为检查点仅包含权重和变量,并且对图形的结构一无所知
有没有简单的代码来加载一个.pb文件然后运行它?【参考方案2】:
你可以试试下面的代码:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
【讨论】:
以上是关于TensorFlow 从文件中保存/加载图形的主要内容,如果未能解决你的问题,请参考以下文章
在 TensorFlow Functional API 中保存和加载具有相同图形的多个模型