如何保存表示在 Tensorflow 中构建的神经网络的对象

Posted

技术标签:

【中文标题】如何保存表示在 Tensorflow 中构建的神经网络的对象【英文标题】:How to save object representing a neural network constructed in Tensorflow 【发布时间】:2019-06-06 02:22:27 【问题描述】:

我是 Tensorflow 的新手,正在 github 上玩一些代码。此代码为神经网络创建了一个类,其中包括构建网络、制定损失函数、训练网络、执行预测等方法。

骨架代码如下所示:

class NeuralNetwork:
    def __init__(...):

    def initializeNN():

    def trainNN():

    def predictNN():

等等。神经网络是用 Tensorflow 构建的,因此类定义及其方法使用 tensorflow 语法。

现在在我的脚本的主要部分,我通过

创建这个类的一个实例
model = NeuralNetwork(...)

并使用model.predict等模型的方法来生成绘图。

由于训练神经网络需要很长时间,我想保存对象“模型”以供将来使用,并有可能调用其方法。我试过泡菜和莳萝,但都失败了。对于泡菜,我得到了错误:

TypeError: can't pickle _thread.RLock objects

而对于莳萝,我得到了:

TypeError:无法腌制 SwigPyObject 对象

有什么建议可以保存对象并仍然能够调用它的方法吗?这是必不可少的,因为我可能希望在未来对一组不同的点进行预测。

谢谢!

【问题讨论】:

您是否尝试过使用tf.train.saver? Guide 这如何适合我上面的代码?你介意给我举个例子吗?我不明白这如何使我能够访问我的对象的方法,例如执行预测的方法。我想保存上面的“模型”对象本身 Tensorflow: how to save/restore a model?的可能重复 【参考方案1】:

你应该做的是:

# Build the graph
model = NeuralNetwork(...)
# Create a train saver/loader object
saver = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Train the model in the same way you are doing it currently
    model.train_model()
    # Once you are done training, just save the model definition and it's learned weights
    saver.save(sess, save_path)

而且,你已经完成了。那么当你想再次使用模型时,你可以做的是:

# Build the graph
model = NeuralNetwork()
# Create a train saver/loader object
loader = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Load the model variables
    loader.restore(sess, save_path)
    # Train the model again for example
    model.train_model()

【讨论】:

感谢您的帮助,戈尔扬。但是有两个澄清:我需要使用扩展名“ckpt”对吗?另外,您的示例是否意味着我必须再次训练神经网络?我试过了,python 抱怨“从检查点恢复失败” 不需要使用ckpt扩展。你也不需要再次训练你的模型。这就是稍后保存和恢复它的意义所在。我的观点是 model.train_model() 示例,您可以在恢复模型后对模型执行任何操作。

以上是关于如何保存表示在 Tensorflow 中构建的神经网络的对象的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 中的连体神经网络

不怕学不会 | 使用TensorFlow从零开始构建卷积神经网络

Tensorflow保存神经网络参数有妙招:Saver和Restore

如何保存Tensorflow中的Tensor参数,保存训练中的中间参数,存储卷积层的数据

增强拓扑(NEAT)神经网络的神经进化可以在 TensorFlow 中构建吗?

如何在 TensorFlow 中恢复多个神经网络模型?