Tensorflow:成功恢复检查点后丢失重置

Posted

技术标签:

【中文标题】Tensorflow:成功恢复检查点后丢失重置【英文标题】:Tensorflow: loss resets after successfully restored checkpoint 【发布时间】:2017-05-10 18:32:10 【问题描述】:

保存或恢复时没有错误。权重似乎已正确恢复。

我正在尝试按照karpathy/min-char-rnn.py、sherjilozair/char-rnn-tensorflow 和Tensorflow RNN tutorial 构建我自己的最小字符级RNN。我的脚本似乎按预期工作,除非我尝试恢复/恢复训练。

如果我重新启动脚本并从检查点恢复,然后恢复训练,则损失总是会恢复,就好像没有检查点一样(尽管权重已正确恢复)。 但是,在脚本执行过程中,如果我重置图表、启动新会话并恢复,则可以继续按预期将损失最小化。

我已尝试在我的台式机(使用 GPU)和笔记本电脑(仅 CPU)上运行此程序,两者都在装有 Tensorflow 0.12 的 Windows 上运行。

下面是我的代码,我这里上传了代码+数据+控制台输出: https://gist.github.com/dk1027/777c3da7ba1ff7739b5f5e89491bef73

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn_cell

class model_input:

    def __init__(self,data_path, batch_size, steps):
        self.batch_idx = 0
        self.data_path = data_path
        self.steps = steps
        self.batch_size = batch_size
        data = open(self.data_path).read()
        data_size = len(data)
        self.vocab = set(data)
        self.vocab_size = len(self.vocab)
        self.vocab_to_idx = v:i for i,v in enumerate(self.vocab)
        self.idx_to_vocab = i:v for i,v in enumerate(self.vocab)
        c = self.batch_size * self.steps
        #Offset by 1 character because we want to predict the next character
        _data_as_idx = np.asarray([self.vocab_to_idx[v] for v in data], dtype=np.int32)
        self.X = _data_as_idx[:-1]
        self.Y = _data_as_idx[1:]

    def reset(self):
        self.batch_idx = 0

    def next_batch2(self):
        i = self.batch_idx
        j = self.batch_idx + self.batch_size * self.steps

        if j >= self.X.shape[0]:
            i = 0
            j = self.batch_size * self.steps
            self.batch_idx = 0

        #print("next_batch: (%s,%s)" %(i,j))
        x = self.X[i:j]
        x = x.reshape(-1,self.steps)

        _xlen = x.shape[0]
        _y = self.Y[i:j]
        _y = _y.reshape(-1,self.steps)
        self.batch_idx += 1

        return x, _y

    def toIdx(self, s):
        res = []
        for _s in s:
            res.append(self.vocab_to_idx[_s])
        return res

    def toStr(self, idx):
        s = ''
        for i in idx:
            s += self.idx_to_vocab[i]
        return s

class Config():
    def __init__(self):
        # Parameters
        self.learning_rate = 0.001
        self.training_iters = 10000
        self.batch_size = 20
        self.display_step = 200
        self.max_epoch = 1
        # Network Parameters
        self.n_input = 1 # 1 character input
        self.n_steps = 25 # sequence length
        self.n_hidden = 128 # hidden layer num of features
        self.n_rnn_layers = 2
        # To be set later
        self.vocab_size = None

# Train
def Train(sess, model, data, config, saver):
    init_state = sess.run(model.initial_state)
    data.reset()
    epoch = 0
    while epoch < config.max_epoch:
        # Keep training until reach max iterations
        step = 0
        while step * config.batch_size < config.training_iters:
            # Run optimization op (backprop)
            fetch_dict = 
                "cost": model.cost,
                "final_state": model.final_state,
                "op" : model.train_op
            
            feed_dict = 
            for i, (c, h) in enumerate(model.initial_state):
                feed_dict[c] = init_state[i].c
                feed_dict[h] = init_state[i].h
            batch_x, batch_y = data.next_batch2()
            feed_dict[model.x]=batch_x
            feed_dict[model.y]=batch_y
            fetches = sess.run(fetch_dict, feed_dict=feed_dict)

            if (step % config.display_step) == 0:
                print("Iter " + str(step*config.batch_size) + ", Minibatch Loss=:.7f".format(fetches["cost"]))
            step += 1
            if (step*config.batch_size % 5000) == 0:
                sp = saver.save(sess, config.save_path + "model.ckpt", global_step = step * config.batch_size + epoch * config.training_iters)
                print("Saved to %s" % sp)
        sp = saver.save(sess, config.save_path + "model.ckpt", global_step = step * config.batch_size + epoch * config.training_iters)
        print("Saved to %s" % sp)
        epoch += 1

    print("Optimization Finished!")


class Model():
    def __init__(self, config):
        self.config = config

        lstm_cell = rnn_cell.BasicLSTMCell(config.n_hidden, state_is_tuple=True)

        self.cell = rnn_cell.MultiRNNCell([lstm_cell] * config.n_rnn_layers, state_is_tuple=True)

        self.x = tf.placeholder(tf.int32, [config.batch_size, config.n_steps])
        self.y = tf.placeholder(tf.int32, [config.batch_size, config.n_steps]) 
        self.initial_state = self.cell.zero_state(config.batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [config.vocab_size, config.n_hidden], dtype=tf.float32)
            inputs = tf.nn.embedding_lookup(embedding, self.x)
        outputs = []
        state = self.initial_state
        with tf.variable_scope('rnn'):
            softmax_w = tf.get_variable("softmax_w", [config.n_hidden, config.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [config.vocab_size])

            for time_step in range(config.n_steps):
                if time_step > 0: tf.get_variable_scope().reuse_variables()
                (cell_output, state) = self.cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(1, outputs), [-1, config.n_hidden])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [self.logits],
            [self.y],
            [tf.ones([config.batch_size * config.n_steps], dtype=tf.float32)],
            name="seq2seq")

        self.cost = tf.reduce_sum(loss) / config.batch_size
        self.final_state = state

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),5)
        optimizer = tf.train.AdamOptimizer(config.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

def main():
    # Read input data
    data_path = "1sonnet.txt"
    save_path = "./save/"

    config = Config()
    data = model_input(data_path, config.batch_size, config.n_steps)
    config.vocab_size = data.vocab_size
    config.data_path = data_path
    config.save_path = save_path

    train_model = Model(config)
    print("Model defined.")

    bReproProblem = True
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("restored from %s" % ckpt.model_checkpoint_path)

        Train(sess, train_model, data, config, saver)


    if bReproProblem:
        tf.reset_default_graph() #reset everything
        data.reset()
        train_model2 = Model(config)
        print("Starting a new session, restore from checkpoint, and train again")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver2 = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(save_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver2.restore(sess, ckpt.model_checkpoint_path)
                print("restored from %s" % ckpt.model_checkpoint_path)

            Train(sess, train_model2, data, config, saver2)


if __name__ == '__main__':
    main()

【问题讨论】:

【参考方案1】:

TL;DR

请确保您的标签在每次运行代码时相同,尤其是对于那些使用列表索引作为标签的人。

详情请见this question。

如果您使用列表索引作为标签,请对数据进行排序或将索引保存到磁盘。使用:

labels = sorted(set(data))

而不是

labels = set(data))

一般建议

在 Python 实现中,有一些方法,如 set()os.listdir(),返回一个未排序的集合。换句话说,每次运行时项目的索引可能不同。

对于set(),Python use a random method 构建一个set。对于os.listdir()、it doesn't promise the order of the returned list。因此,对于健壮的代码,建议对您的数据集使用sorted()

关于您的问题

data_size = len(data)
self.vocab = set(data)
self.vocab_size = len(self.vocab)
self.vocab_to_idx = v:i for i,v in enumerate(self.vocab)
self.idx_to_vocab = i:v for i,v in enumerate(self.vocab)

这可能是由您构建标签的方式引起的。 vocab_to_idx 可能在您每次运行代码时都不同。

只需添加一个sorted()

data_size = len(data)
self.vocab = sorted(set(data))
self.vocab_size = len(self.vocab)
self.vocab_to_idx = v:i for i,v in enumerate(self.vocab)
self.idx_to_vocab = i:v for i,v in enumerate(self.vocab)

【讨论】:

以上是关于Tensorflow:成功恢复检查点后丢失重置的主要内容,如果未能解决你的问题,请参考以下文章

在 Windows 重置后,我的 Pycharm Projects 文件夹所在的用户文件夹丢失并且无法恢复文件

使用WP重置插件后,WordPress仪表板选项消失了

Swift:Facebook SDK,成功登录后检查丢失的权限

回顾我的设计:网站密码重置工具

系统升级/重装导致金蝶数据库账套丢失找回

服务器数据恢复成功率预测