Tensorflow:保存和恢复模型参数
Posted
技术标签:
【中文标题】Tensorflow:保存和恢复模型参数【英文标题】:Tensorflow: Saving and restoring the model parameters 【发布时间】:2016-08-16 17:19:25 【问题描述】:我是 TensorFlow 的初学者,目前正在训练 CNN。
我正在使用 Saver 来保存模型使用的参数,但我担心这本身是否会存储模型使用的所有变量,并且足以将这些值恢复到重新运行程序以在经过训练的网络上执行分类/测试。
让我们看一下 TensorFlow 给出的著名示例 MNIST。
在示例中,我们有一堆卷积块,所有这些块都有权重和偏差变量,这些变量在程序运行时会被初始化。
W_conv1 = init_weight([5,5,1,32])
b_conv1 = init_bias([32])
在处理了几个层之后,我们创建一个会话,并初始化所有添加到图中的变量。
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
这里,是否可以将saver.save代码注释掉,训练结束后用saver.restore(sess,file_path)替换,以便将权重、偏差等参数恢复回图?应该是这样吗?
for i in range(1000):
...
if i%500 == 0:
saver.save(sess,"model%d.cpkt"%(i))
我目前正在对大型数据集进行培训,因此终止和重新开始培训是浪费时间和资源,因此我请求有人在我开始培训之前澄清一下。
【问题讨论】:
有点不清楚你在问什么。 “评论 saver.save 代码,并用 saver.restore(sess,file_path) 替换它”您不想存储您的训练值并从以前的训练中重置(通过恢复)? “所以终止,然后重新开始训练是浪费”。这意味着您想在完成所有训练后保存一次模型? @Sung Kim:你的后一个问题的答案是肯定的。我不打算使用存储的值重新开始训练,而是在完成训练后简单地保存模型一次。因为在Matlab中,这很简单,而且其实我是第一次用Python编程,用TensorFlow,所以不知道有没有其他优雅的保存参数的方法。 【参考方案1】:如果您只想保存一次最终结果,您可以这样做:
with tf.Session() as sess:
for i in range(1000):
...
path = saver.save(sess, "model.ckpt") # out of the loop
print "Saved:", path
在其他程序中,您可以使用从 saver.save 返回的路径加载模型以进行预测或其他操作。您可以在https://github.com/sugyan/tensorflow-mnist 看到一些示例。
【讨论】:
【参考方案2】:基于here 和 Sung Kim 解决方案中的解释,我为这个问题编写了一个非常简单的模型。基本上以这种方式,您需要从同一类创建一个对象并从保护程序中恢复其变量。您可以在here找到此解决方案的示例。
【讨论】:
以上是关于Tensorflow:保存和恢复模型参数的主要内容,如果未能解决你的问题,请参考以下文章
tensorflow saver 保存和恢复指定 tensor
tensorflow 1.0 学习:模型的保存与恢复(Saver)
Tensorflow - Tutorial : Variables的保存与恢复