在张量流中导入图形时使用新操作
Posted
技术标签:
【中文标题】在张量流中导入图形时使用新操作【英文标题】:Using new op while importing graph in tensorflow 【发布时间】:2017-05-09 22:18:46 【问题描述】:我是 TensorFlow 新手。我正在尝试使用检查点文件导入经过训练的 TensorFlow 网络。我正在使用的网络有一个自定义操作,当我在 Python 中使用它时可以正常工作。但是,我必须冻结图表,因为我必须使用 C++ API。我正在使用 TensorFlow 基本目录中的以下命令调用 freeze_graph
:
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=<local path>/data/graph_vgg.pb --input_checkpoint=<local path>/data/VGGnet_fast_rcnn_iter_70000.ckpt --output_node_names="cls_prob,bbox_pred" --output_graph=<local path>/graph_frozen.pb
但是,当我尝试冻结图表时出现以下错误。
Traceback (most recent call last):
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 202, in <module>
app.run(main=main, argv=[sys.argv[0]] + unparsed)
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 134, in main
FLAGS.variable_names_blacklist)
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 99, in freeze_graph
_ = importer.import_graph_def(input_graph_def, name="")
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/framework/importer.py", line 260, in import_graph_def
raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named RoiPool in defined operations.
输入图有一个节点,其操作类型为 RoiPool
,TensorFlow 无法识别该节点。我调查了引发此错误的代码,它看起来像是操作未在 TensorFlow 中注册。我有内置的.so
文件。我应该把它复制到某个地方吗?我在网上找不到类似的东西。任何帮助或指示都会很棒。我在这个问题上花了很多时间。该代码在 python 中运行良好,使用 op 的层位于项目目录中。请帮助我了解我需要做什么才能使其正常工作。
编辑:这是网络中使用的code of custom op。
【问题讨论】:
您需要将自定义操作链接到冻结图二进制文件中。您当然可以通过编辑 freeze_graph 二进制文件的构建文件并使用自定义操作作为额外依赖项进行重建来做到这一点。您可能还想在 github 上提交功能请求,要求在 freeze_graph 中提供一些更易于使用的自定义操作支持。 @PeterHawkins 感谢您的回复。我还有一个问题:存储user_ops
的首选位置是:<tensorflow_base>/tensorflow/user_ops
或<tensorflow_base>/tensorflow/core/user_ops
。
我尝试了建议并使用我的 user_op 作为依赖项构建了 freeze_graph。 Freeze_graph 仍然给出同样的错误。还有其他建议或建议吗?
【参考方案1】:
我不熟悉特定的 RoiPooling 实现,但我通常设置需要冻结的自定义操作的方式是 roi_pooling_op.cc 和关联的 python 文件(定义渐变并导入 *.so)都位于在 //tensorflow/user_ops 中。
//tensorflow/user_ops目录下的BUILD文件应该有
tf_custom_op_library(
name = "roi_pooling_op.so",
srcs = ["roipooling_op.cc"],
)
py_library(
name = "roi_pooling_op_py",
srcs = ["roi_pooling.py"],
data = [":roi_pooling_op.so"],
srcs_version = "PY2AND3",
)
* Tensorflow 文档中没有提到 data = [":roi_pooling_op.so"]
,但这样您就不必深入了解本地 bazel-bin 目录,而是可以使用 tf.resource_loader.get_path_to_datafile
导入 *.so
_roi_pooling_module = tf.load_op_library(tf.resource_loader.get_path_to_datafile("roi_pooling_op.so"))
roi_pool = _roi_pooling_module.roi_pool
roi_pool_grad = _roi_pooling_module.roi_pool_grad
@ops.RegisterGradient("RoiPool")
def _roi_pool_grad(op, grad, _):
grad_out = roi_pool_grad(...)
return grad_out, None
更新冻结构建,在 BUILD 文件 //tensorflow/python/tools 目录中,添加 "//tensorflow/user_ops:roi_pooling_op_py",
作为 freeze_graph py_binary 的依赖项。
最后重新构建和安装所有东西(custom-op、freeze_graph 和 pip 包/轮)
bazel build --config opt //tensorflow/user_ops:roi_pooling_op.so
bazel build --config opt //tensorflow/user_ops:roi_pooling_op_py
bazel build --config opt //tensorflow/python/tools:freeze_graph
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install --ignore-installed --upgrade /tmp/tensorflow_pkg/tensorflow-1.2.1-py2-none-any.whl
现在你可以在你的 python 代码中使用它了
from tensorflow.user_ops import roi_pooling
现在您应该可以毫无问题地冻结图表了。
【讨论】:
【参考方案2】:我遵循了 Jared 的回答,我认为它让我大部分时间都受益,但我需要来自 https://***.com/a/37556646/7004026 的最后一篇文章。我在调用import_graph_def
之前直接在freeze_graph.py
中插入了tf.load_op_library('/path/to/custom_op.so')
。然后我就可以冻结我的图表了。
【讨论】:
以上是关于在张量流中导入图形时使用新操作的主要内容,如果未能解决你的问题,请参考以下文章
Keras,张量流在 sublime text 和 spyder 中导入错误,但在命令行中工作