Tensorflow模型的 暂存 恢复 微调 保存 加载
Posted by jhc888007
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow模型的 暂存 恢复 微调 保存 加载相关的知识,希望对你有一定的参考价值。
- 暂存模型(*.index 为参数名称索引,*.meta 为模型图,*.data* 为参数值)
# Build a tiny graph, initialize its variables and write a checkpoint under
# MODEL_DIR/MODEL_NAME (*.index, *.meta and *.data* files).
# Assumes `tf` (TensorFlow 1.x API), MODEL_DIR and MODEL_NAME are defined
# elsewhere.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
# `with` guarantees the session is closed even if init or save raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
- 暂存模型(同一模型多次保存可以不保存模型图节省时间)
# Save the same model several times. Only the first save writes the .meta
# graph file; later saves pass write_meta_graph=False to skip it, since the
# graph has not changed.
# Assumes `tf`, `time`, MODEL_DIR, MODEL_NAME, MODEL1_NAME and MODEL2_NAME
# are defined elsewhere.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    # The pauses only space the saves apart in time.
    time.sleep(5)
    # Checkpoint prefix is exactly MODEL_DIR/MODEL1_NAME, same format as above
    # (no extra suffix appended to the name).
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
    time.sleep(5)
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)
- 恢复模型(手动生成网络则不需要*.meta文件)
# Restore a checkpoint into a graph rebuilt by hand in code. Because the graph
# is recreated here, the .meta file is not needed; the Saver matches saved
# values to these variables by name.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore() assigns the saved values, so no initializer run is needed.
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    print(sess.run([weights]))
- 恢复模型(从*.meta文件生成网络)
# Restore without rebuilding the graph in code: import_meta_graph recreates
# the graph from MODEL_NAME.meta and returns a Saver bound to it.
tf.reset_default_graph()
saver = tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    # No Python handle exists for the variable, so look the tensor up by
    # its graph name ("<op name>:<output index>").
    weights_tensor = tf.get_default_graph().get_tensor_by_name("weights:0")
    print(sess.run([weights_tensor]))
- 恢复模型(可以在一个文件夹下保存多次模型,checkpoint文件会自动记录所有模型名称和最后一次记录模型名称)
# Restore the most recent checkpoint recorded in MODEL_DIR's "checkpoint"
# bookkeeping file into a hand-built graph.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
if ckpt is None or not ckpt.model_checkpoint_path:
    # get_checkpoint_state returns None when MODEL_DIR holds no checkpoint
    # file; fail with a clear message instead of an AttributeError on None.
    raise FileNotFoundError("no checkpoint found in %s" % MODEL_DIR)
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(sess.run([weights]))
- 微调模型(恢复之前训练模型的部分参数,加上新参数,继续训练)
def get_variables_available_in_checkpoint(variables, checkpoint_path,
                                          include_global_step=True):
    """Return the subset of `variables` that can be restored from a checkpoint.

    Args:
        variables: dict mapping variable name (op name) to tf.Variable.
        checkpoint_path: checkpoint prefix, e.g. "MODEL_DIR/MODEL_NAME".
        include_global_step: if False, the global-step entry of the
            checkpoint is ignored, so that variable is never reported as
            available.

    Returns:
        dict containing the entries of `variables` whose name exists in the
        checkpoint with an identical shape, inserted in name-sorted order.
    """
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
        ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    vars_in_ckpt = {}
    for variable_name, variable in sorted(variables.items()):
        ckpt_shape = ckpt_vars_to_shape_map.get(variable_name)
        # New variables, or ones whose shape changed, are skipped and keep
        # their own initializer.
        if ckpt_shape is not None and ckpt_shape == variable.shape.as_list():
            vars_in_ckpt[variable_name] = variable
    return vars_in_ckpt


# Fine-tuning: restore the previously trained parameters that still fit,
# add a new one, then continue from there.
tf.reset_default_graph()
weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
# Newly added parameter that is not part of the earlier checkpoint.
other_weights = tf.Variable(tf.zeros([10, 10]))
variables_to_init = tf.global_variables()
variables_to_init_dict = {var.op.name: var for var in variables_to_init}
ckpt_path = "%s/%s" % (MODEL_DIR, MODEL_NAME)
available_var_map = get_variables_available_in_checkpoint(
    variables_to_init_dict, ckpt_path, include_global_step=False)
# init_from_checkpoint overrides the initializers of the mapped variables
# with checkpoint values; it must be called before
# tf.global_variables_initializer() builds the init op below.
tf.train.init_from_checkpoint(ckpt_path, available_var_map)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights]))
- 保存模型(冻结为二进制 .pb 模型,变量转换为常量)
# Freeze the restored model into a single binary GraphDef (.pb): variables
# are replaced by constants holding their restored values.
from tensorflow.python.framework.graph_util import convert_variables_to_constants

tf.reset_default_graph()
saver = tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    # Only the nodes needed to compute the listed output nodes are kept.
    graph_out = convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["weights"])
# graph_out is a standalone GraphDef proto; the session is no longer needed.
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
    output.write(graph_out.SerializeToString())
- 加载模型(加载冻结的二进制 .pb 模型,其中参数已是常量,无需初始化变量)
# Load a frozen .pb GraphDef and read a tensor from it. The frozen graph
# contains constants instead of variables, so no restore or initializer is
# required.
tf.reset_default_graph()
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# name="" imports nodes without the default "import/" prefix, so the tensor
# keeps its original name "weights:0".
tf.import_graph_def(graph_def, name="")
with tf.Session() as sess:
    print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))
参考文献:
https://blog.csdn.net/loveliuzz/article/details/81661875
https://www.cnblogs.com/bbird/p/9951943.html
https://blog.csdn.net/gzj_1101/article/details/80299610
以上是关于Tensorflow模型的 暂存 恢复 微调 保存 加载的主要内容,如果未能解决你的问题,请参考以下文章