在 Tensorflow 中,获取图中所有张量的名称

Posted

技术标签:

【中文标题】在 Tensorflow 中,获取图中所有张量的名称【英文标题】:In Tensorflow, get the names of all the Tensors in a graph 【发布时间】:2016-08-21 09:42:28 【问题描述】:

我正在使用Tensorflowskflow 创建神经网络;出于某种原因,我想获取给定输入的一些内部张量的值,所以我使用myClassifier.get_layer_value(input, "tensorName")myClassifier 作为skflow.estimators.TensorFlowEstimator

但是,我发现很难找到张量名称的正确语法,即使知道它的名称(而且我对运算和张量感到困惑),所以我使用 tensorboard 来绘制图形并查找名字。

有没有办法在不使用张量板的情况下枚举图中的所有张量?

【问题讨论】:

【参考方案1】:

你可以的

[n.name for n in tf.get_default_graph().as_graph_def().node]

另外,如果您在 IPython 笔记本中进行原型设计,您可以直接在笔记本中显示图形,请参阅 Alexander's Deep Dream 中的 show_graph 函数 notebook

【讨论】:

你可以过滤这个例如通过在理解的末尾添加 if "Variable" in n.op 来变量。 如果知道名字有没有办法获取特定节点? 了解更多关于图节点的信息:tensorflow.org/extend/tool_developers/#nodes 上面的命令产生所有操作/节点的名称。要获取所有张量的名称,请执行以下操作: tensors_per_node = [node.values() for node in graph.get_operations()] tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]【参考方案2】:

我会试着总结一下答案:

要获取图表中的所有节点(输入tensorflow.core.framework.node_def_pb2.NodeDef

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

要获取图表中的所有操作(输入tensorflow.python.framework.ops.Operation

all_ops = tf.get_default_graph().get_operations()

要获取图表中的所有变量(输入tensorflow.python.ops.resource_variable_ops.ResourceVariable

all_vars = tf.global_variables()

要获取图表中的所有张量(输入tensorflow.python.framework.ops.Tensor

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

要获取图表中的所有占位符(输入tensorflow.python.framework.ops.Tensor

all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]

张量流 2

要在 Tensorflow 2 中获取图形,而不是 tf.get_default_graph(),您需要先实例化 tf.function 并访问 graph 属性,例如:

graph = func.get_concrete_function().graph

其中functf.function

【讨论】:

请注意那个 TF2 版本!【参考方案3】:

通过使用get_operations,有一种方法可以比雅罗斯拉夫的回答稍微快一点。这是一个简单的例子:

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))

【讨论】:

您无法使用tf.get_operations() 获取张量。只有你能得到的操作。【参考方案4】:

tf.all_variables()可以得到你想要的信息。

另外,this commit 今天在 TensorFlow Learn 中创建,它在估算器中提供了一个函数 get_variable_names,您可以使用它轻松检索所有变量名称。

【讨论】:

此功能已弃用 ... 其后继者是tf.global_variables() 这只会获取变量,而不是张量。 在 Tensorflow 1.9.0 中显示 all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02 module 'tensorflow' has no attribute 'all_variables'【参考方案5】:

我认为这也可以:

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

但是对比萨尔瓦多和雅罗斯拉夫的答案,我不知道哪个更好。

【讨论】:

这个使用了从张量流对象检测 API 中使用的 frozen_inference_graph.pb 文件导入的图。谢谢【参考方案6】:

接受的答案只为您提供带有名称的字符串列表。我更喜欢一种不同的方法,它可以让你(几乎)直接访问张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples 现在包含每个张量,每个张量都在一个元组中。您还可以对其进行调整以直接获取张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

【讨论】:

【参考方案7】:

由于 OP 要求的是张量列表而不是操作/节点列表,因此代码应该略有不同:

graph = tf.get_default_graph()    
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

【讨论】:

【参考方案8】:

以前的答案很好,我只想分享一个我写的从图表中选择张量的实用函数:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
    """Selects nodes' names in the graph if:
    - The name contains all items in and_conds
    - OR/AND depending on op
    - The name contains any item in or_conds

    Condition starting with a "!" are negated.
    Returns all ops if no optional arguments is given.

    Args:
        graph (tf.Graph): The graph containing sought tensors
        and_conds (list(str)), optional): Defaults to None.
            "and" conditions
        op (str, optional): Defaults to 'and'. 
            How to link the and_conds and or_conds:
            with an 'and' or an 'or'
        or_conds (list(str), optional): Defaults to None.
            "or conditions"

    Returns:
        list(str): list of relevant tensor names
    """
    assert op in 'and', 'or'

    if and_conds is None:
        and_conds = ['']
    if or_conds is None:
        or_conds = ['']

    node_names = [n.name for n in graph.as_graph_def().node]

    ands = 
        n for n in node_names
        if all(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in and_conds
        )

    ors = 
        n for n in node_names
        if any(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in or_conds
        )

    if op == 'and':
        return [
            n for n in node_names
            if n in ands.intersection(ors)
        ]
    elif op == 'or':
        return [
            n for n in node_names
            if n in ands.union(ors)
        ]

所以如果你有一个带有操作的图表:

['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

然后运行

get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

返回:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

【讨论】:

【参考方案9】:

以下解决方案适用于 TensorFlow 2.3 -

def load_pb(path_to_pb):
    with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
tf_graph = load_pb(MODEL_FILE)
sess = tf.compat.v1.Session(graph=tf_graph)

# Show tensor names in graph
for op in tf_graph.get_operations():
    print(op.values())

MODEL_FILE 是您的冻结图的路径。

取自here。

【讨论】:

【参考方案10】:

这对我有用:

for n in tf.get_default_graph().as_graph_def().node:
    print('\n',n)

【讨论】:

以上是关于在 Tensorflow 中,获取图中所有张量的名称的主要内容,如果未能解决你的问题,请参考以下文章

在 Tensorflow.js 中获取张量中项目的值

《30天吃掉那只 TensorFlow2.0》 2-1 张量数据结构

将张量数据裁剪到包围体

张量流图中的梯度是不是计算不正确?

在 TensorFlow 中,如何使用 python 从张量中获取非零值及其索引?

如何在图构建时获取张量的维度(在 TensorFlow 中)?