tensorflow的断点续训

Posted sienbo

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow的断点续训相关的知识,希望对你有一定的参考价值。

tensorflow的断点续训

2019-09-07

顾名思义,断点续训的意思是因为某些原因模型还没有训练完成就被中断,下一次训练可以在上一次训练的基础上继续训练而不用从头开始;这种方式对于你那些训练时间很长的模型来说非常友好。

如果要进行断点续训,那么得满足两个条件:

(1)本地保存了模型训练中的快照;(即断点数据保存)

(2)可以通过读取快照恢复模型训练的现场环境。(断点数据恢复)

这两个操作都用到了tensorflow中的train.Saver类。

 

1.tensorflow.trainn.Saver类

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
这里不对所有参数进行介绍,只介绍常用的参数
max_to_keep:允许保存的模型的个数,默认为5;当保存的个数超过5时,自动删除最旧的模型,以保证最多同时存在5个模型;如果设置为0或者None,则会对所有训练中的模型进行保存,但是这样除了多占硬盘外没什么意义。
其他的参数一般就使用默认值就可以了。
saver = tf.train.Saver(max_to_keep=10)

有机会再补充其他参数的用法。

2.断点数据的保存

使用saver对象的save方法即可保存模型:

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix=meta,
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False,
    save_debug_info=False
)

常用参数:

sess:需要保存的会话,一般就是我们程序中的sess;

save_path:保存模型的文件路径以及名称,例如“ckpt/my_model”,注意如果要保存在ckpt文件夹下,那么需要在ckpt后面加个斜杠/;

global_step:训练次数,saver会自动将这个值加入到保存的文件名字中。

saver.save(sess,"my_model",global_step=1)
saver.save(sess,"my_model",global_step=100)
saver.save(sess,"ckpt/my_model",global_step=1)

其中1,2,3行代码分别会:

1:在代码的路径下生成名为“my_model_1文件”;

2:在代码的路径下生成名为“my_model_100文件”;

3:在ckpt文件夹下生成名为“my_model_1文件”。

 最常见的用法:

for epoch in range(n_iter):
    ‘‘‘
    training process
    ‘‘‘
    saver.save(sess,ckpt_dir+"model_name",global_step=epoch)

其中ckpt_dir是断点数据存放的路径。

 

3.断点数据的恢复

先建立一个与之前相同的模型;然后再检查有没有断点数据,如果有,则进行恢复。

ckpt_dir = "ckpt/"
#创建Saver对象
saver = tf.train.Saver()
#如果有断点文件,读取最近的断点文件
ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess,ckpt)

不需要提供模型的名字,tf.train.latest_checkpoint(ckpt_dir)会去ckpt_dir文件夹中自动寻找最新的模型文件。

以上是关于tensorflow的断点续训的主要内容,如果未能解决你的问题,请参考以下文章

pytorch实现断点续训

Pytorch分布式训练与断点续训

Pytorch分布式训练与断点续训

第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像

第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像

PyTorch保存模型断点以及加载断点继续训练