转载:tensorflow保存训练后的模型
Posted 佟学强
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了转载:tensorflow保存训练后的模型相关的知识,希望对你有一定的参考价值。
训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存。如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值。建议可以使用Saver类保存和加载模型的结果。
1、使用tf.train.Saver.save()方法保存模型
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix=‘meta‘, write_meta_graph=True, write_state=True)
- sess: 用于保存变量操作的会话。
- save_path: String类型,用于指定训练结果的保存路径。
- global_step: 如果提供的话,这个数字会添加到save_path后面,用于构建checkpoint文件。这个参数有助于我们区分不同训练阶段的结果。
2、使用tf.train.Saver.restore方法价值模型
tf.train.Saver.restore(sess, save_path)
- sess: 用于加载变量操作的会话。
- save_path: 同保存模型是用到的的save_path参数。
下面通过一个代码演示这两个函数的使用方法
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ‘‘
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + ‘model.ckpt‘, global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
以上是关于转载:tensorflow保存训练后的模型的主要内容,如果未能解决你的问题,请参考以下文章
如何在 TensorFlow 2 中保存/加载模型的一部分?
如何将 TensorFlow 模型转换为 TFLite 模型