Python TensorFlow:如何使用优化器和 import_meta_graph 重新开始训练?

Posted

技术标签:

【中文标题】Python TensorFlow:如何使用优化器和 import_meta_graph 重新开始训练?【英文标题】:Python TensorFlow: How to restart training with optimizer and import_meta_graph? 【发布时间】:2017-08-31 19:06:36 【问题描述】:

我正在尝试通过从中断的地方重新开始在 TensorFlow 中进行模型训练。我想使用最近添加的(我认为是 0.12+)import_meta_graph(),以免重建图形。

我已经看到了解决方案,例如Tensorflow: How to save/restore a model?,但我遇到了 AdamOptimizer 的问题,特别是我收到了 ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used 错误。 This can be fixed by initializing,但是我的模型值被清除了!

还有其他答案和一些完整示例,但它们似乎总是较旧,因此不包括较新的 import_meta_graph() 方法,或者没有非张量优化器。我能找到的最接近的问题是tensorflow: saving and restoring session,但没有最终明确的解决方案,而且示例非常复杂。

理想情况下,我想要一个简单的可运行示例,从头开始,停止,然后重新开始。我有一些有用的东西(如下),但也想知道我是否遗漏了一些东西。肯定不止我一个人这样做吗?

【问题讨论】:

我在使用 AdamOptimizer 时遇到了同样的问题。通过将我的操作放入集合中,我设法让事情发挥作用。这个例子对我帮助很大:seaandsailor.com/tensorflow-checkpointing.html 【参考方案1】:

这是我通过阅读文档、其他类似解决方案以及反复试验得出的结论。这是一个简单的随机数据自动编码器。如果运行,然后再次运行,它将从停止的地方继续(即第一次运行的成本函数从 ~0.5 -> 0.3 第二次运行开始 ~0.3)。除非我遗漏了什么,否则所有的保存、构造函数、模型构建、add_to_collection 都是需要的,并且按照精确的顺序,但可能有更简单的方法。

是的,这里并不真正需要使用import_meta_graph 加载图表,因为代码就在上面,但这是我在实际应用程序中想要的。

from __future__ import print_function
import tensorflow as tf
import os
import math
import numpy as np

output_dir = "/root/Data/temp"
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")

input_length = 10
encoded_length = 3
learning_rate = 0.001
n_epochs = 10
n_batches = 10
if not os.path.exists(model_checkpoint_file_base + ".meta"):
    print("Making new")
    brand_new = True

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in")
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length],
                                          -1.0 / math.sqrt(input_length),
                                          1.0 / math.sqrt(input_length)), name="W_enc")
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc")
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded")
    W_dec = tf.transpose(W_enc, name="W_dec")
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec")
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded")
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")

    saver = tf.train.Saver()
else:
    print("Reloading existing")
    brand_new = False
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")
    g = tf.get_default_graph()
    x_in = g.get_tensor_by_name("x_in:0")
    cost = g.get_tensor_by_name("cost:0")


sess = tf.Session()
if brand_new:
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    init = tf.global_variables_initializer()
    sess.run(init)
    tf.add_to_collection("optimizer", optimizer)
else:
    saver.restore(sess, model_checkpoint_file_base)
    optimizer = tf.get_collection("optimizer")[0]

for epoch_i in range(n_epochs):
    for batch in range(n_batches):
        batch = np.random.rand(50, input_length)
        _, curr_cost = sess.run([optimizer, cost], feed_dict=x_in: batch)
        print("batch_cost:", curr_cost)
        save_path = tf.train.Saver().save(sess, model_checkpoint_file_base)

【讨论】:

【参考方案2】:

我遇到了同样的问题,我只是弄清楚了哪里出了问题,至少在我的代码中。

最后我在saver.restore()中使用了错误的文件名。该函数必须指定不带文件扩展名的文件名,就像saver.save() 函数一样:

saver.restore(sess, 'model-1')

而不是

saver.restore(sess, 'model-1.data-00000-of-00001')

有了这个,我会做你想做的事:从头开始,停止,然后重新开始。我不需要使用tf.train.import_meta_graph() 函数从元文件初始化第二个保护程序,也不需要在初始化优化器后显式声明tf.initialize_all_variables()

我的完整模型还原如下所示:

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'model-1')

我认为在协议 V1 中,您仍然必须将.ckpt 添加到文件名中,而对于import_meta_graph(),您仍然需要添加.meta,这可能会引起用户的一些混淆。也许应该在文档中更明确地指出这一点。

【讨论】:

【参考方案3】:

在恢复会话中创建保护程序对象时可能会出现问题。

在恢复会话中使用以下代码时,我遇到了与您相同的错误。

saver = tf.train.import_meta_graph('tmp/hsmodel.meta')
saver.restore(sess, tf.train.latest_checkpoint('tmp/'))

但是当我这样改变时,

saver = tf.train.Saver()
saver.restore(sess, "tmp/hsmodel")

错误已消失。 “tmp/hsmodel”是我在保存会话中提供给 saver.save(sess,"tmp/hsmodel") 的路径。

这里有一个关于存储和恢复训练 MNIST 网络(包含 Adam 优化器)会话的简单示例。这有助于我与我的代码进行比较并解决问题。

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

【讨论】:

【参考方案4】:

saver 类允许我们通过以下方式保存会话: saver.save(sess, "checkpoints.ckpt")

并允许我们恢复会话: saver.restore(sess, tf.train.latest_checkpoint("checkpoints.ckpt"))

【讨论】:

以上是关于Python TensorFlow:如何使用优化器和 import_meta_graph 重新开始训练?的主要内容,如果未能解决你的问题,请参考以下文章

如何选择优化器 optimizer

在 Tensorflow 中使用 Adadelta 优化器时出现未初始化值错误

如何在 tensorflow 2.0.0 中使用 Lazy Adam 优化器

如何使用 Tensorflow 2 在我的自定义优化器上更新可训练变量

如何在 Tensorflow 中创建优化器

Tensorflow Adam 优化器与 Keras Adam 优化器