TensorFlow training checkpoints with tf.train.Saver



#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf
g1=tf.Graph()

with g1.as_default(): 
    with tf.name_scope("input_Variable"):        
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:  
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)
    # step 0: record the initial value before any update
    step,var,summary=sess.run([my_step,my_var,merged_summaries])
    writer.add_summary(summary,global_step=step)
    print step,var
    saver=tf.train.Saver()
    # steps 1-49: update my_var, log a summary, and checkpoint every 5 steps
    for i in xrange(1,50):
        sess.run(addop)
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var
        if i%5==0:
            saver.save(sess,'./myvar-model/myvar-model',global_step=i)
    saver.save(sess,'./myvar-model/myvar-model',global_step=49)

    writer.flush()
    writer.close()

Partial output (the last few steps):

38 0.0512373
39 0.04996785
40 0.048759546
41 0.04760808
42 0.04650955
43 0.045460388
44 0.04445735
45 0.04349747
46 0.042578023
47 0.041696515
48 0.040850647
49 0.04003831
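
One point worth noting about the loop above: by default tf.train.Saver keeps only the five most recent checkpoint files, so although a checkpoint is written every 5 steps, older ones are deleted as newer ones arrive. Below is a minimal sketch of the constructor arguments that control this and of the files each save produces, reusing the sess and paths from the script above:

# Sketch: a Saver keeps only the 5 newest checkpoints unless told otherwise.
saver = tf.train.Saver(max_to_keep=10,                   # retain the 10 newest checkpoints
                       keep_checkpoint_every_n_hours=1)  # additionally keep one per hour

# Each save with global_step=45 writes files such as
#   myvar-model-45.index, myvar-model-45.data-00000-of-00001, myvar-model-45.meta
# and updates the plain-text "checkpoint" file recording the latest model path.
save_path = saver.save(sess, './myvar-model/myvar-model', global_step=45)
print save_path  # ./myvar-model/myvar-model-45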

In this way the variables of the dataflow graph are saved to binary checkpoint files.
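
To verify what such a checkpoint actually stores, the variables it contains can be listed and read back. A minimal sketch, assuming the ./myvar-model directory produced by the run above; tf.train.latest_checkpoint and tf.train.NewCheckpointReader are the TF 1.x utilities used here:

#!/usr/bin/env python2
# Sketch: inspect the contents of the newest checkpoint in ./myvar-model
import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('./myvar-model')  # newest checkpoint prefix, or None
if ckpt_path is not None:
    reader = tf.train.NewCheckpointReader(ckpt_path)
    # Every variable stored in the checkpoint, e.g. 'input_Variable/Variable'
    # and 'global_step/Variable', with its shape and current value.
    for name, shape in reader.get_variable_to_shape_map().items():
        print name, shape, reader.get_tensor(name)

The next script builds on this: before training it looks for an existing checkpoint in the same directory and, if one is found, restores the model and resumes from the saved step.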

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]
"""

import tensorflow as tf
import os
g1=tf.Graph()

with g1.as_default(): 
    with tf.name_scope("input_Variable"):        
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32,trainable=False)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:  
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)

    saver=tf.train.Saver()

    # If a checkpoint was saved previously, restore the model and resume from it
    init_step=0
    ckpt=tf.train.get_checkpoint_state(os.getcwd()+'/myvar-model')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)
        init_step=int(ckpt.model_checkpoint_path.rsplit('-',1)[1])
        print "Reading checkpoint file..."
    for i in xrange(init_step,100):
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var,init_step
        if i%5==0 and i<=50:
            print "保存检查点文件"
            saver.save(sess,‘./myvar-model/myvar-model‘,global_step=i)
        sess.run(addop)

    writer.flush()
    writer.close()

When the code above is run for the first time, checkpoint files are saved; from the second run onward the latest checkpoint is read back and the loop resumes from step=50.
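
A side note on recovering the step counter: because my_step is itself stored in the checkpoint, its value can also be read back directly after restoring, instead of being parsed out of the file name. Here is a sketch under the same graph and directory as above, using tf.train.latest_checkpoint (which returns the newest checkpoint prefix, or None if there is none):

# Sketch: resume logic using the restored my_step variable from the graph above.
latest = tf.train.latest_checkpoint('./myvar-model')  # e.g. './myvar-model/myvar-model-50'
if latest is not None:
    saver.restore(sess, latest)
    init_step = sess.run(my_step)  # read the restored step counter directly
else:
    sess.run(init)                 # no checkpoint yet: start from scratch
    init_step = 0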

Output of the second run:

Reading checkpoint file...
50 0.03925755 50
Saving checkpoint file
51 0.038506564 50
52 0.037783686 50
53 0.03708737 50
54 0.036416177 50
55 0.035768777 50
56 0.03514393 50
...
...
...
93 0.021334965 50
94 0.02111056 50
95 0.02089082 50
96 0.0206756 50
97 0.020464761 50
98 0.020258171 50
99 0.020055704 50
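
Finally, because each call to saver.save() also writes a .meta file containing the graph definition, a checkpoint can be restored in a fresh process without rebuilding the graph in Python. A sketch, assuming the files written by the scripts above; the tensor name 'input_Variable/Variable:0' follows from the name scopes used there:

#!/usr/bin/env python2
# Sketch: rebuild the graph from the .meta file and restore the variables.
import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('./myvar-model')   # e.g. './myvar-model/myvar-model-50'
saver = tf.train.import_meta_graph(ckpt_path + '.meta')   # recreate the graph structure
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)                        # load the saved variable values
    my_var = tf.get_default_graph().get_tensor_by_name('input_Variable/Variable:0')
    print sess.run(my_var)                                # e.g. about 0.039 at step 50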
