Tensorflow 中图表中的张量名称列表
Posted
技术标签:
【中文标题】Tensorflow 中图表中的张量名称列表【英文标题】:List of tensor names in graph in Tensorflow 【发布时间】:2016-05-22 01:41:53 【问题描述】:Tensorflow 中的图形对象有一个名为“get_tensor_by_name(name)”的方法。无论如何要获得有效张量名称的列表吗?
如果没有,有人知道预训练模型 inception-v3 from here 的有效名称吗?在他们的示例中,pool_3 是一个有效的张量,但所有这些张量的列表都会很好。我查看了the paper referred to,其中一些层似乎与表 1 中的大小相对应,但不是全部。
【问题讨论】:
不推荐对etarion's answer、op.values()
进行小更新。请改用op.outputs
。
【参考方案1】:
论文没有准确地反映模型。如果你从 arxiv 下载源代码,它有一个准确的模型描述,如 model.txt,其中的名称与已发布模型中的名称密切相关。
为了回答您的第一个问题,sess.graph.get_operations()
为您提供了操作列表。对于一个操作,op.name
给你名称,op.values()
给你它产生的张量列表(在 inception-v3 模型中,所有张量名称都是带有“:0”的操作名称,所以@ 987654325@是最终池化操作产生的张量。)
【讨论】:
感谢您的快速回答! model.txt 和我在这个预训练模型中看到的输出形状似乎仍然存在一些差异。例如,如果我查看“pool:0”,我猜它是第一个池化层,我得到的形状是 73x73x64,但在 model.txt 中它之后的层的输入是 73x73x80。还是我误会了什么? @john 我没有深入挖掘 model.txt 中的 cmets,我认为 cmets 中存在一些不一致之处 - 我没有发现非 cmets 中的不一致之处。对于那个池化层,之前的卷积有 64 个过滤器组(Conv 分配给 conv_2 的第二个参数),所以池化层也有 64 个通道。 80是下一个conv层的输出数... @etarion 我在哪里可以下载这个model.txt?请问可以直接给我链接吗?提前致谢 @OleksandrKhryplyvenko 它是arxiv.org/format/1512.00567v3源代码分发的一部分 @etarion 谢谢!顺便说一句,我已经构建了 inception v3 图。正如您告诉我的那样,在图形构建期间过多的内存消耗是由重复条目引起的。【参考方案2】:以上答案都是正确的。对于上述任务,我遇到了一个易于理解/简单的代码。所以在这里分享:-
import tensorflow as tf
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
printTensors("path-to-my-pbfile.pb")
【讨论】:
我从 ParseFromString 得到一个google.protobuf.message.DecodeError: Error parsing message
。但我正在使用saved_model
。为什么 TF 中的模型有这么多格式?有点烂。【参考方案3】:
查看图中的操作(你会看到很多,所以为了简短起见,我在这里只给出了第一个字符串)。
sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]
out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
【讨论】:
【参考方案4】:您甚至不必创建会话即可查看图表中所有操作名称的名称。为此,您只需获取默认图形tf.get_default_graph()
并提取所有操作:.get_operations
。每个操作都有many fields,你需要的是名字。
代码如下:
import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c
for i in tf.get_default_graph().get_operations():
print i.name
【讨论】:
这可行,但由于某种原因,初始模型(我尝试过的唯一模型)在其大多数名称中都有一个:0
,并且上面的i.name
代码没有反映它。这是为什么呢?
AttributeError: module 'tensorflow' has no attribute 'get_default_graph' - 这个代码也适用于 tf2.x 吗?
要在 >= Tf 2.0 中获取图形,不要使用 tf.get_default_graph()。相反,您需要先实例化一个 tf.function 并按如下方式访问 graph 属性:graph = func.get_concrete_function().graph【参考方案5】:
作为嵌套列表推导:
tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]
获取图表中张量名称的函数(默认为默认图表):
def get_names(graph=tf.get_default_graph()):
return [t.name for op in graph.get_operations() for t in op.values()]
在图中获取张量的功能(默认为默认图):
def get_tensors(graph=tf.get_default_graph()):
return [t for op in graph.get_operations() for t in op.values()]
【讨论】:
【参考方案6】:saved_model_cli
是 TF 附带的替代命令行工具,可能在您处理“SavedModel”格式时很有用。来自docs
!saved_model_cli show --dir /tmp/mobilenet/1 --tag_set serve --all
此输出可能有用,例如:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['dense_input'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1280)
name: serving_default_dense_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['dense_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
【讨论】:
以上是关于Tensorflow 中图表中的张量名称列表的主要内容,如果未能解决你的问题,请参考以下文章