TensorFlow: Freezing a model seems to only store the output nodes?

Posted: 2017-09-16 18:48:12

Question:

I am trying to freeze my trained TensorFlow model. The model is taken from the tutorial Deep MNIST for Experts:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x_vector, w_matrix):
    return tf.nn.conv2d(x_vector, w_matrix, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x_vector):
    return tf.nn.max_pool(x_vector, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


output_graph_name = 'my_graph.pb'

# Create model

label_count = 12

x = tf.placeholder(tf.float32, shape=[None, 1024], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, label_count], name="y_")

w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 32, 32, 1])

h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

w_fc1 = weight_variable([8 * 8 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

w_fc2 = weight_variable([1024, label_count])
b_fc2 = bias_variable([label_count])

y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # ... train, etc etc ...
    # train_step.run(feed_dict={x: train_set[0],
    #                           y_: train_set[1],
    #                           keep_prob: 0.5})

    # Save the variables to disk.
    save_path = saver.save(sess, "my_model.ckpt")

    # Save graph
    tf.train.write_graph(sess.graph_def, '.', output_graph_name, as_text=False)
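The 8 * 8 * 64 used for w_fc1 and h_pool2_flat follows from the input size. A quick check, assuming TensorFlow's 'SAME' padding rule (output size = ceil(input size / stride)), which applies to both conv2d and max_pool here; same_output_size is a hypothetical helper name:

```python
import math


def same_output_size(size: int, stride: int) -> int:
    # TensorFlow 'SAME' padding: output size = ceil(input size / stride),
    # regardless of the kernel or pooling window size.
    return math.ceil(size / stride)


side = 32                          # x (1024 values) is reshaped to [-1, 32, 32, 1]
side = same_output_size(side, 1)   # conv2d, stride 1 -> 32
side = same_output_size(side, 2)   # h_pool1, stride 2 -> 16
side = same_output_size(side, 1)   # conv2d, stride 1 -> 16
side = same_output_size(side, 2)   # h_pool2, stride 2 -> 8
flat_size = side * side * 64       # 64 feature maps from w_conv2

print(side, flat_size)  # 8 4096
```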

Then I try to freeze my model:

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(input_graph=output_graph_name,
                          input_saver="",
                          input_binary=True,
                          input_checkpoint="my_model.ckpt",  # same prefix as saver.save(...)
                          output_node_names="y_",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")

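Now the frozen graph contains only the y_ placeholder, not the whole network connected to it. graph_util.extract_sub_graph extracts only y_. Why is that? How do I freeze the whole network? Should I use y_conv instead of y_?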

(In the graph image that accompanied the question, the top "Placeholder" node is y_ and the bottom "Placeholder" node is x.)
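Why only y_ survives: freezing keeps just the nodes that the requested output nodes depend on. The sketch below is a simplified model of that pruning (a breadth-first walk backwards over node inputs, roughly what graph_util.extract_sub_graph does; the real function also handles control inputs and validates names). GRAPH is a hypothetical, heavily abbreviated stand-in whose names loosely follow the question's graph, and extract_subgraph is an illustrative name, not the TensorFlow API. Because y_ is a placeholder with no inputs, pruning from it leaves only y_.

```python
from collections import deque

# Hypothetical, abbreviated stand-in for the question's graph:
# each node name maps to the names of the nodes it reads from.
GRAPH = {
    "x": [],
    "y_": [],
    "Placeholder": [],                        # keep_prob (left unnamed in the question)
    "Variable": [],                           # all weights/biases collapsed into one node
    "Reshape": ["x"],
    "Conv2D": ["Reshape", "Variable"],        # stands in for every hidden layer
    "dropout/mul": ["Conv2D", "Placeholder"],
    "add_3": ["dropout/mul", "Variable"],     # y_conv
    "ArgMax": ["add_3"],
    "ArgMax_1": ["y_"],
    "Equal": ["ArgMax", "ArgMax_1"],
}


def extract_subgraph(graph, dest_nodes):
    """Keep dest_nodes plus every node they transitively read from."""
    keep = set()
    queue = deque(dest_nodes)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(graph[name])  # walk backwards along input edges
    return {name: inputs for name, inputs in graph.items() if name in keep}


print(sorted(extract_subgraph(GRAPH, ["y_"])))
# ['y_']
print(sorted(extract_subgraph(GRAPH, ["add_3"])))
# ['Conv2D', 'Placeholder', 'Reshape', 'Variable', 'add_3', 'dropout/mul', 'x']
```

Pruning from the prediction keeps the whole network from x forward, but not y_ or the accuracy ops. It also keeps keep_prob's Placeholder, because this model applies dropout unconditionally, so inference on the frozen graph still needs keep_prob fed (typically 1.0).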

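Comments:

freeze_graph is not a tf function! Where does it come from? Try adding tf.identity(y_conv, name='my_output') and use that as output_node_names. Freezing generally extracts only the parts of the graph that are needed, and from y_ to y_ seems to be your subgraph.

It comes from from tensorflow.python.tools import freeze_graph. As for the second part, I don't understand how only y_ is relevant. I expected the entire graph to appear, from x (the input node) to y_ (the output node).

y_ is an input node; it is where computation starts during training. The prediction y_conv and the cost node cross_entropy are the output nodes. So my answer is the same: add tf.identity(y_conv, name='my_output') and use output_node_names='my_output'.

Indeed! I misunderstood the computation graph. Could you turn your comment into an answer so I can accept it?

Answer 1: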

To convince yourself once more that y_ really is not an output node, add the following code:

# dump graph: recursively print every tensor that t depends on
def childs(t, d=0):
    print('-' * d, t.name)
    for child in t.op.inputs:
        childs(child, d + 1)

childs(accuracy)

The output is:

Mean_1:0
- Cast_1:0
-- Equal:0
--- ArgMax:0
---- add_3:0
----- MatMul_1:0
------ dropout/mul:0
------- dropout/div:0
-------- Relu_2:0
--------- add_2:0
---------- MatMul:0
----------- Reshape_1:0
------------ MaxPool_1:0
------------- Relu_1:0
-------------- add_1:0
--------------- Conv2D_1:0
---------------- MaxPool:0
----------------- Relu:0
------------------ add:0
------------------- Conv2D:0
-------------------- Reshape:0
--------------------- x:0
--------------------- Reshape/shape:0
-------------------- Variable/read:0
--------------------- Variable:0
------------------- Variable_1/read:0
-------------------- Variable_1:0
---------------- Variable_2/read:0
----------------- Variable_2:0
--------------- Variable_3/read:0
---------------- Variable_3:0
------------ Reshape_1/shape:0
----------- Variable_4/read:0
------------ Variable_4:0
---------- Variable_5/read:0
----------- Variable_5:0
-------- Placeholder:0
------- dropout/Floor:0
-------- dropout/add:0
--------- Placeholder:0
--------- dropout/random_uniform:0
---------- dropout/random_uniform/mul:0
----------- dropout/random_uniform/RandomUniform:0
------------ dropout/Shape:0
------------- Relu_2:0
-------------- add_2:0
--------------- MatMul:0
---------------- Reshape_1:0
----------------- MaxPool_1:0
------------------ Relu_1:0
------------------- add_1:0
-------------------- Conv2D_1:0
--------------------- MaxPool:0
---------------------- Relu:0
----------------------- add:0
------------------------ Conv2D:0
------------------------- Reshape:0
-------------------------- x:0
-------------------------- Reshape/shape:0
------------------------- Variable/read:0
-------------------------- Variable:0
------------------------ Variable_1/read:0
------------------------- Variable_1:0
--------------------- Variable_2/read:0
---------------------- Variable_2:0
-------------------- Variable_3/read:0
--------------------- Variable_3:0
----------------- Reshape_1/shape:0
---------------- Variable_4/read:0
----------------- Variable_4:0
--------------- Variable_5/read:0
---------------- Variable_5:0
----------- dropout/random_uniform/sub:0
------------ dropout/random_uniform/max:0
------------ dropout/random_uniform/min:0
---------- dropout/random_uniform/min:0
------ Variable_6/read:0
------- Variable_6:0
----- Variable_7/read:0
------ Variable_7:0
---- ArgMax/dimension:0
--- ArgMax_1:0
---- y_:0
---- ArgMax_1/dimension:0
- Const_5:0
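The dump prints the Relu_2 subtree twice, because dropout reads Relu_2's output along two paths (through dropout/div and through dropout/Shape). A variant that remembers visited nodes expands each node once. This is a sketch: dump_inputs is a hypothetical name, and it takes accessor callbacks so it can be demonstrated on a toy dict; with TF 1.x one would pass lambda t: t.op.inputs and lambda t: t.name.

```python
def dump_inputs(root, get_inputs, get_name):
    """Return the input tree of root as indented lines, expanding each node once."""
    lines, seen = [], set()

    def visit(node, depth):
        name = get_name(node)
        prefix = "-" * depth + (" " if depth else "")
        if name in seen:
            # Already expanded earlier: emit a back-reference instead of the subtree.
            lines.append(f"{prefix}{name} (see above)")
            return
        seen.add(name)
        lines.append(f"{prefix}{name}")
        for child in get_inputs(node):
            visit(child, depth + 1)

    visit(root, 0)
    return lines


# Hypothetical, abbreviated piece of the dropout part of the dump.
toy = {
    "dropout/mul:0": ["dropout/div:0", "dropout/Floor:0"],
    "dropout/div:0": ["Relu_2:0"],
    "dropout/Floor:0": ["dropout/Shape:0"],
    "dropout/Shape:0": ["Relu_2:0"],
    "Relu_2:0": [],
}
print("\n".join(dump_inputs("dropout/mul:0", toy.__getitem__, lambda n: n)))
# dropout/mul:0
# - dropout/div:0
# -- Relu_2:0
# - dropout/Floor:0
# -- dropout/Shape:0
# --- Relu_2:0 (see above)
```

On the real graph the call would be print("\n".join(dump_inputs(accuracy, lambda t: t.op.inputs, lambda t: t.name))).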

The solution is:

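# Name the prediction explicitly. This has to be added while building the model,
# before saver.save(...) and tf.train.write_graph(...), so "my_output" is in my_graph.pb.
tf.identity(y_conv, name='my_output')

# Then freeze, pruning from the prediction instead of from the y_ placeholder.
freeze_graph.freeze_graph(input_graph=output_graph_name,
                          input_saver="",
                          input_binary=True,
                          input_checkpoint="my_model.ckpt",  # same prefix as saver.save(...)
                          output_node_names="my_output",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")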
