转 如何阅读TensorFlow源码
Posted xiaoniu-666
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了转 如何阅读TensorFlow源码相关的知识,希望对你有一定的参考价值。
通过bazel学习之后,大概了解了TensorFlow的项目的源文件和描述文件。
下面是一篇不错的介绍,搬砖here。
在静下心来默默看了大半年机器学习的资料并做了些实践后,打算学习下现在热门的TensorFlow的实现,毕竟系统这块和自己关系较大。本文会简单的说明一下如何阅读TensorFlow的源码。最重要的是了解其构建工具bazel以及脚本语言调用c或cpp的包裹工具swig。这里假设大家对bazel及swig以及有所了解(不了解的可以google下)。要看代码首先要知道代码怎么构建,因此本文的一大部分会关注构建这块。
如果从源码构建TensorFlow会需要执行如下命令:
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
对应的BUILD文件的rule为:
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
"MANIFEST.in",
"README",
"setup.py",
"//tensorflow/core:framework_headers",
":other_headers",
":simple_console",
"//tensorflow:tensorflow_py",
"//tensorflow/examples/tutorials/mnist:package",
"//tensorflow/models/embedding:package",
"//tensorflow/models/image/cifar10:all_files",
"//tensorflow/models/image/mnist:convolutional",
"//tensorflow/models/rnn:package",
"//tensorflow/models/rnn/ptb:package",
"//tensorflow/models/rnn/translate:package",
"//tensorflow/tensorboard",
],
)
sh_binary在这里的主要作用是生成data的这些依赖。一个一个来看,一开始的三个文件MANIFEST.in、README、setup.py是直接存在的,因此不会有什么操作。
“//tensorflow/core:framework_headers”:
其对应的rule为:
filegroup(
name = "framework_headers",
srcs = [
"framework/allocator.h",
......
"util/device_name_utils.h",
],
)
这里filegroup的作用是给这一堆头文件一个别名,方便其他rule引用。
“:other_headers”:
rule为:
transitive_hdrs(
name = "other_headers",
deps = [
"//third_party/eigen3",
"//tensorflow/core:protos_all_cc",
],
)
transitive_hdrs的定义在:
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
outputs = set()
for dep in ctx.attr.deps:
outputs += dep.cc.transitive_headers
return struct(files=outputs)
_transitive_hdrs = rule(attrs={
"deps": attr.label_list(allow_files=True,
providers=["cc"]),
},
implementation=_transitive_hdrs_impl,)
def transitive_hdrs(name, deps=[], **kwargs):
_transitive_hdrs(name=name + "_gather",
deps=deps)
native.filegroup(name=name,
srcs=[":" + name + "_gather"])
其作用依旧是收集依赖需要的头文件。
“:simple_console”:
其rule为:
py_binary(
name = "simple_console",
srcs = ["simple_console.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
py_library( name = "tensorflow_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python"], )
simple_console.py的代码的主要部分是:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import code
import sys
def main(_):
"""Run an interactive console."""
code.interact()
return 0
if __name__ == ‘__main__‘:
sys.exit(main(sys.argv))
可以看到起通过deps = [“//tensorflow/python”]构建了依赖包,然后生成了对应的执行文件。看下依赖的rule规则。
//tensorflow/python对应的rule为:
py_library(
name = "python",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__pkg__"],
deps = [
":client",
":client_testlib",
":framework",
":framework_test_lib",
":kernel_tests/gradient_checker",
":platform",
":platform_test",
":summary",
":training",
"//tensorflow/contrib:contrib_py",
],
)
py_library(
name = "training",
srcs = glob(
["training/**/*.py"],
exclude = ["**/*test*"],
),
srcs_version = "PY2AND3",
deps = [
":client",
":framework",
":lib",
":ops",
":protos_all_py",
":pywrap_tensorflow",
":training_ops",
],
)
这里其依赖的pywrap_tensorflow的rule为:
tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
swig_includes = [
"client/device_lib.i",
"client/events_writer.i",
"client/server_lib.i",
"client/tf_session.i",
"framework/python_op_gen.i",
"lib/core/py_func.i",
"lib/core/status.i",
"lib/core/status_helper.i",
"lib/core/strings.i",
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"platform/numpy.i",
"util/port.i",
"util/py_checkpoint_reader.i",
],
deps = [
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
)
tf_py_wrap_cc为其自己实现的一个rule,这里的.i就是SWIG的interface文件。来看下其实现:
def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
module_name = name.split("/")[-1]
# Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
# and use that as the name for the rule producing the .so file.
cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
extra_deps = []
_py_wrap_cc(name=name + "_py_wrap",
srcs=srcs,
swig_includes=swig_includes,
deps=deps + extra_deps,
module_name=module_name,
py_module_name=name)
native.cc_binary(
name=cc_library_name,
srcs=[module_name + ".cc"],
copts=(copts + ["-Wno-self-assign", "-Wno-write-strings"]
+ tf_extension_copts()),
linkopts=tf_extension_linkopts(),
linkstatic=1,
linkshared=1,
deps=deps + extra_deps)
native.py_library(name=name,
srcs=[":" + name + ".py"],
srcs_version="PY2AND3",
data=[":" + cc_library_name])
_py_wrap_cc = rule(attrs={
"srcs": attr.label_list(mandatory=True,
allow_files=True,),
"swig_includes": attr.label_list(cfg=DATA_CFG,
allow_files=True,),
"deps": attr.label_list(allow_files=True,
providers=["cc"],),
"swig_deps": attr.label(default=Label(
"//tensorflow:swig")), # swig_templates
"module_name": attr.string(mandatory=True),
"py_module_name": attr.string(mandatory=True),
"swig_binary": attr.label(default=Label("//tensorflow:swig"),
cfg=HOST_CFG,
executable=True,
allow_files=True,),
},
outputs={
"cc_out": "%{module_name}.cc",
"py_out": "%{py_module_name}.py",
},
implementation=_py_wrap_cc_impl,)
_py_wrap_cc_impl的实现为:
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
srcs = ctx.files.srcs
if len(srcs) != 1:
fail("Exactly one SWIG source file label must be specified.", "srcs")
module_name = ctx.attr.module_name
cc_out = ctx.outputs.cc_out
py_out = ctx.outputs.py_out
src = ctx.files.srcs[0]
args = ["-c++", "-python"]
args += ["-module", module_name]
args += ["-l" + f.path for f in ctx.files.swig_includes]
cc_include_dirs = set()
cc_includes = set()
for dep in ctx.attr.deps:
cc_include_dirs += [h.dirname for h in dep.cc.transitive_headers]
cc_includes += dep.cc.transitive_headers
args += ["-I" + x for x in cc_include_dirs]
args += ["-I" + ctx.label.workspace_root]
args += ["-o", cc_out.path]
args += ["-outdir", py_out.dirname]
args += [src.path]
outputs = [cc_out, py_out]
ctx.action(executable=ctx.executable.swig_binary,
arguments=args,
mnemonic="PythonSwig",
inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes +
ctx.attr.swig_deps.files),
outputs=outputs,
progress_message="SWIGing {input}".format(input=src.path))
return struct(files=set(outputs))
# If possible, read swig path out of "swig_path" generated by configure
SWIG=swig
SWIG_PATH=tensorflow/tools/swig/swig_path
if [ -e $SWIG_PATH ]; then
SWIG=`cat $SWIG_PATH`
fi
# If this line fails, rerun configure to set the path to swig correctly
"$SWIG" "[email protected]"
可以看到起就是调用了swig命令。
“//tensorflow:tensorflow_py”:
其rule为:
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
可以看到起主要依赖了我们上面生成的”//tensorflow/python”这个module。
剩余的几个其实和主框架关系不大,主要是生成一些model、文档啥的。
现在清楚了其构建链后,我们来看个简单的程序,其通过梯度下降算法求线性拟合的W和b。我们会从这个例子入手看下如何找到其使用的函数的具体实现的源码位置:
(python3.5)? tmp cat th.py
import tensorflow as tf
import numpy as np
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for step in range(0, 201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
(python3.5)? tmp python th.py
0 [ 0.42190057] [ 0.17155224]
20 [ 0.1743494] [ 0.26045772]
40 [ 0.11817314] [ 0.29033473]
60 [ 0.10444205] [ 0.29763755]
80 [ 0.10108578] [ 0.29942256]
100 [ 0.10026541] [ 0.29985884]
120 [ 0.10006487] [ 0.2999655]
140 [ 0.10001585] [ 0.29999158]
160 [ 0.10000388] [ 0.29999796]
180 [ 0.10000096] [ 0.29999951]
200 [ 0.10000025] [ 0.29999989]
from tensorflow.python import *
? python grep ‘class Variable(‘ -R ./* ./ops/variables.py:class Variable(object):
对于tf.Session等也可以用同样的方法定位。我们来找个走SWIG包裹的,如果我们去看sess.run,我们会看到如下的代码:
return tf_session.TF_Run(session, options, feed_dict, fetch_list, target_list, run_metadata)
这里tf_session就是一个SWIG包裹的模块:
from tensorflow.python import pywrap_tensorflow as tf_session
pywrap_tensorflow在源码里是找不到的,因为这个得从SWIG生成后才有,我们可以从.i文件里找下TF_Run的声明,或者直接grep下这个函数:
? tensorflow grep ‘TF_Run(‘ -R ./* ./core/client/tensor_c_api.cc:void TF_Run(TF_Session* s, const TF_Buffer* run_options,
这样就可以看其实现了:
void TF_Run(TF_Session* s, const TF_Buffer* run_options, // Input tensors const char** c_input_names, TF_Tensor** c_inputs, int ninputs, // Output tensors const char** c_output_tensor_names, TF_Tensor** c_outputs, int noutputs, // Target nodes const char** c_target_node_names, int ntargets, TF_Buffer* run_metadata, TF_Status* status) { TF_Run_Helper(s, nullptr, run_options, c_input_names, c_inputs, ninputs, c_output_tensor_names, c_outputs, noutputs, c_target_node_names, ntargets, run_metadata, status); }
以上是关于转 如何阅读TensorFlow源码的主要内容,如果未能解决你的问题,请参考以下文章
分享《21个项目玩转深度学习:基于TensorFlow的实践详解》+PDF+源码+何之源