TensorFlow:有没有办法将冻结图转换为检查点模型?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TensorFlow:有没有办法将冻结图转换为检查点模型?相关的知识,希望对你有一定的参考价值。

可以将检查点模型转换为冻结图(.ckpt文件到.pb文件)。但是,有没有一种将pb文件再次转换为检查点文件的反向方法?

我想它需要将常量转换回变量 - 有没有办法将正确的常量识别为变量并将它们恢复为检查点模型?

目前支持将变量转换为常量:https://www.tensorflow.org/api_docs/python/tf/graph_util/convert_variables_to_constants

但不是相反。

这里提出了类似的问题:Tensorflow: Convert constant tensor from pre-trained Vgg model to variable

但该解决方案依赖于使用ckpt模型来恢复权重变量。有没有办法从PB文件而不是检查点文件中恢复权重变量?这对于重量修剪可能很有用。

答案

如果您有构建网络的源代码,则可以相对简单地完成,因为冻结图方法没有改变Convolutions / Fully connected的名称,因此您基本上可以研究图形并将常量操作与其变量相匹配匹配并只使用常量值加载变量。

如果你没有构建网络的代码,它仍然可以完成,但它并不是直接的。

例如,您可以搜索图中的所有节点并查找类型为Constant的操作,然后在找到类型为Constant的所有操作后,您可以看到操作是否连接到Convolution / Fully connected,例如..(或者您可以只是转换它依赖于你的所有常量)。

在找到要转换为变量的常量后,可以将变量添加到保存常量值的图形中,然后使用Tensorflow graph editor重新连接const操作与变量之间的连接(使用reroute_ts方法)。

完成后你可以保存你的图形,当你再次加载它时你会得到你的变量(但要注意常量仍然会保留在你的图形中,但是它们可以通过graph-transform工具进行优化)

另一答案

如果您有构建网络的源代码,则可以相对简单地完成,因为冻结图方法没有改变Convolutions / Fully connected的名称,因此您基本上可以研究图形并将常量操作与其变量相匹配匹配并只使用常量值加载变量。 - 来自Almog David

感谢@ Almog David上面的优秀答案;我正面临着同样的情况

  • 我有frozen_inference_graph.pb但没有检查站;
  • 我有生成frozen_inference_graph.pb的源代码,但我不知道参数。

以下是解决困境的三个步骤。

1.从frozen_inference_graph.pb获取节点名称和值对

import tensorflow as tf
from tensorflow.python.framework import tensor_util

def get_node_values(old_graph_path):
    old_graph = tf.Graph()
    with old_graph.as_default():
        old_graph_def = tf.GraphDef()
        with tf.gfile.GFile(old_graph_path, "rb") as fid:
            serialized_graph = fid.read()
            old_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(old_graph_def, name='')

    old_sess = tf.Session(graph=old_graph)

    # get all the nodes from the graph def
    nodes = old_sess.graph.as_graph_def().node

    value_dict = 
    for node in nodes:
        value = node.attr['value'].tensor
        try:
            # get name and value (numpy array) from tensor 
            value_dict[node.name] = tensor_util.MakeNdarray(value) 
        except:
            # some tensor doesn't have value; for example np.squeeze
            # just ignore it 
            pass
    return value_dict

value_dict = get_node_values("frozen_inference_graph.pb")

2.使用现有代码创建新图表;调整模型参数,直到新图形中的所有节点都出现在value_dict

new_graph = tf.Graph()
with new_graph.as_default():
    tf.create_global_step()
    #existing code 
    # ...
    # ...
    # ...

    model_variables = tf.model_variables()
    unseen_variables = set(model_variable.name[:-2] for model_variable in model_variables) - set(value_dict.keys())
    print  ("\n".join(sorted(list(unseen_variables))))

3.将值分配给变量并保存到检查点(或保存到图表)

new_graph_path = "model.ckpt"
saver = tf.train.Saver(model_variables)

assign_ops = []
for variable in model_variables:
    print ("Assigning", variable.name[:-2])
    # variable names have ":0" but constant names doesn't have.
    value = value_dict[variable.name[:-2]]
    assign_ops.append(variable.assign(value))

sess =session.Session(graph = new_graph)
sess.run(tf.global_variables_initializer())
sess.run(assign_ops)
saver.save(sess, new_graph_path+"model.ckpt")

这是我能想到解决这个问题的唯一方法。但是,它仍然存在一些缺点:如果你重新加载模型检查点,你会发现(沿着所有有用的变量)很多不需要的assign变量,如Assign_700/value。这是不可避免的,看起来很难看。如果您有更好的建议,请随时发表评论。谢谢。

另一答案

有一种方法可以通过Graph Editor在TensorFlow中将常量转换回可训练的变量。但是,您需要指定要转换的节点,因为我不确定是否有办法以健壮的方式自动检测到这些节点。

以下是步骤:

Step 1: Load frozen graph

我们将.pb文件加载到图形对象中。

import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')

Step 2: Find constants that need conversion

以下是列出图表中节点名称的两种方法:

  • 使用this script打印它们
  • print([n.name for n in tf_graph.as_graph_def().node])

您想要转换的节点很可能被命名为“Const”。可以肯定的是,最好在Netron中加载图形,以查看哪些张量存储了可训练的权重。通常,可以安全地假设所有const节点都是变量。

识别出这些节点后,让我们将它们的名称存储到一个列表中:

to_convert = [...] # names of tensors to convert

Step 3: Convert constants to variables

运行此代码以转换指定的常量。它基本上为每个常量创建相应的变量,并使用GraphEditor从图中取消常量,并挂钩变量。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:

    for name in to_convert:
        tensor = g.get_tensor_by_name(':0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape, \  
                      initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

Step 4: Save result as .ckpt

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)

而且中提琴!你应该在这一点上完成:)我能够自己完成这项工作,并确认模型权重得以保留 - 唯一的区别是图表现在可以训练。如果有任何问题,请告诉我。

以上是关于TensorFlow:有没有办法将冻结图转换为检查点模型?的主要内容,如果未能解决你的问题,请参考以下文章

如何将冻结图转换为 TensorFlow lite

将冻结图转换为 TRT 图时 Jetson Nano 上的 TensorRT 错误

在张量流中将 SSD 转换为冻结图。必须使用哪些输出节点名称?

从 Keras 构建 TensoRRT 引擎时出错

将 .tflite 转换为 .pb

将 TensorFlow 模型从 Object Detection API 转换为 uff