TensorFlow: Freezing a model seems to only store the output nodes?
【Posted】: 2017-09-16 18:48:12
【Question】: I am trying to freeze a TensorFlow model that I trained. The model is taken from the tutorial Deep MNIST for Experts:
import tensorflow as tf

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x_vector, w_matrix):
    return tf.nn.conv2d(x_vector, w_matrix, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x_vector):
    return tf.nn.max_pool(x_vector, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

output_graph_name = 'my_graph.pb'

# Create model
label_count = 12
x = tf.placeholder(tf.float32, shape=[None, 1024], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, label_count], name="y_")

# First convolutional layer
w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 32, 32, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# Second convolutional layer
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# Fully connected layer
w_fc1 = weight_variable([8 * 8 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

# Dropout
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Readout layer: y_conv holds the predicted logits
w_fc2 = weight_variable([1024, label_count])
b_fc2 = bias_variable([label_count])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2

# Training and evaluation ops; y_ feeds the true labels into these
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... train, etc etc ...
    # train_step.run(feed_dict={x: train_set[0],
    #                           y_: train_set[1],
    #                           keep_prob: 0.5})

    # Save the variables to disk.
    save_path = saver.save(sess, "my_model.ckpt")

    # Save graph
    tf.train.write_graph(sess.graph_def, '.', output_graph_name, as_text=False)
Then I try to freeze my model:

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(input_graph=output_graph_name,
                          input_saver="",
                          input_binary=True,
                          input_checkpoint="my_model",
                          output_node_names="y_",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")
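Now the frozen graph contains only the y_ placeholder, not the whole associated network. graph_util.extract_sub_graph likewise extracts only y_. Why is that? How can I freeze the entire network? Should I use y_conv instead of y_?
The top "Placeholder" node is for y_; the bottom "Placeholder" node is for x.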
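【Comments】:
freeze_graph is not a tf function! Where does it come from? Try adding tf.identity(y_conv, name='my_output') and use that name as the output_node_name. Most graph freezing extracts only the necessary part of the graph. From y_ to y_ seems to be your subgraph.
It comes from: from tensorflow.python.tools import freeze_graph. As for the second part, I don't understand how only y_ is relevant. I expected the entire graph to be present, from x (the input node) to y_ (the output node).
y_ is an input node; it is where the computation starts during training. The prediction y_conv and the cost node cross_entropy are the output nodes. So my answer stays the same: add tf.identity(y_conv, name='my_output') and use output_node_names='my_output'.
Indeed! I misunderstood the computation graph. Could you upgrade your comment to an answer so that I can accept it?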
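The point made in the comments, that freezing keeps only the nodes the requested output depends on, can be illustrated without TensorFlow at all. The following is a minimal sketch under stated assumptions: `toy_graph` is a hypothetical, heavily simplified stand-in for the model's graph, and `needed_nodes` is an illustrative helper, not the routine freeze_graph actually uses.

```python
# Toy graph: each node name maps to the names of its input nodes.
# The names mirror the question's model, but this dict is a hypothetical
# simplification, not the real TensorFlow graph.
toy_graph = {
    "x": [],                          # placeholder: no inputs
    "y_": [],                         # placeholder: no inputs
    "keep_prob": [],                  # placeholder: no inputs
    "y_conv": ["x", "keep_prob"],     # prediction depends on the input image
    "my_output": ["y_conv"],          # identity op wrapping the prediction
    "cross_entropy": ["y_conv", "y_"],
}


def needed_nodes(graph, output_names):
    """Collect every node reached by walking inputs backward from the outputs."""
    seen = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        stack.extend(graph[name])  # follow edges toward the inputs
    return seen


print(sorted(needed_nodes(toy_graph, ["y_"])))
# → ['y_']
print(sorted(needed_nodes(toy_graph, ["my_output"])))
# → ['keep_prob', 'my_output', 'x', 'y_conv']
```

Because y_ is a placeholder with no inputs, walking backward from it never reaches the network, which matches what the question observed.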
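【Answer 1】:
To convince you once more that y_ is indeed not an output node, add the following code:

# dump graph: print each tensor, then recurse into the inputs of the op that produces it
def childs(t, d=0):
    print('-' * d, t.name)
    for child in t.op.inputs:
        childs(child, d + 1)

childs(accuracy)

The output is: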
Mean_1:0
- Cast_1:0
-- Equal:0
--- ArgMax:0
---- add_3:0
----- MatMul_1:0
------ dropout/mul:0
------- dropout/div:0
-------- Relu_2:0
--------- add_2:0
---------- MatMul:0
----------- Reshape_1:0
------------ MaxPool_1:0
------------- Relu_1:0
-------------- add_1:0
--------------- Conv2D_1:0
---------------- MaxPool:0
----------------- Relu:0
------------------ add:0
------------------- Conv2D:0
-------------------- Reshape:0
--------------------- x:0
--------------------- Reshape/shape:0
-------------------- Variable/read:0
--------------------- Variable:0
------------------- Variable_1/read:0
-------------------- Variable_1:0
---------------- Variable_2/read:0
----------------- Variable_2:0
--------------- Variable_3/read:0
---------------- Variable_3:0
------------ Reshape_1/shape:0
----------- Variable_4/read:0
------------ Variable_4:0
---------- Variable_5/read:0
----------- Variable_5:0
-------- Placeholder:0
------- dropout/Floor:0
-------- dropout/add:0
--------- Placeholder:0
--------- dropout/random_uniform:0
---------- dropout/random_uniform/mul:0
----------- dropout/random_uniform/RandomUniform:0
------------ dropout/Shape:0
------------- Relu_2:0
-------------- add_2:0
--------------- MatMul:0
---------------- Reshape_1:0
----------------- MaxPool_1:0
------------------ Relu_1:0
------------------- add_1:0
-------------------- Conv2D_1:0
--------------------- MaxPool:0
---------------------- Relu:0
----------------------- add:0
------------------------ Conv2D:0
------------------------- Reshape:0
-------------------------- x:0
-------------------------- Reshape/shape:0
------------------------- Variable/read:0
-------------------------- Variable:0
------------------------ Variable_1/read:0
------------------------- Variable_1:0
--------------------- Variable_2/read:0
---------------------- Variable_2:0
-------------------- Variable_3/read:0
--------------------- Variable_3:0
----------------- Reshape_1/shape:0
---------------- Variable_4/read:0
----------------- Variable_4:0
--------------- Variable_5/read:0
---------------- Variable_5:0
----------- dropout/random_uniform/sub:0
------------ dropout/random_uniform/max:0
------------ dropout/random_uniform/min:0
---------- dropout/random_uniform/min:0
------ Variable_6/read:0
------- Variable_6:0
----- Variable_7/read:0
------ Variable_7:0
---- ArgMax/dimension:0
--- ArgMax_1:0
---- y_:0
---- ArgMax_1/dimension:0
- Const_5:0
The solution is:

# Add this while building the model, before the graph is written to disk
tf.identity(y_conv, name='my_output')

freeze_graph.freeze_graph(input_graph=output_graph_name,
                          input_saver="",
                          input_binary=True,
                          input_checkpoint="my_model",
                          output_node_names="my_output",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")
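One way to check the effect of the output node choice is to run graph_util.extract_sub_graph, which the question mentions, on a tiny graph. This is only a sketch under stated assumptions: it needs a TensorFlow install that provides the tf.compat.v1 module, the two-layer graph is a hypothetical stand-in for the real model, the constant `w` stands in for trained weights, and `kept` is an illustrative helper.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API (assumed to be available)

# Build a tiny stand-in graph; shapes and the constant weights are hypothetical.
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 4], name="x")
    y_ = tf1.placeholder(tf.float32, shape=[None, 2], name="y_")
    w = tf.constant([[0.1, 0.2]] * 4, name="w")  # stand-in for trained weights
    y_conv = tf.matmul(x, w, name="y_conv")
    tf.identity(y_conv, name="my_output")

graph_def = graph.as_graph_def()


def kept(output_names):
    """Return the sorted names of the nodes kept for the given output nodes."""
    sub = tf1.graph_util.extract_sub_graph(graph_def, output_names)
    return sorted(node.name for node in sub.node)


print(kept(["y_"]))         # only the label placeholder is kept
print(kept(["my_output"]))  # the prediction path, without y_
```

The same check can be applied to the real graph_def before calling freeze_graph, to confirm that the chosen output node pulls in x and the layers between.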