如何使用 tensorflow 保存 DNN 模型

Posted

技术标签:

【中文标题】如何使用 tensorflow 保存 DNN 模型【英文标题】:how to save a DNN model with tensorflow [duplicate] 【发布时间】:2017-11-09 00:36:53 【问题描述】:

我有训练 DNN 网络的代码。我不想每次都训练这个网络,因为它占用了太多时间。如何保存模型?

def train_model(filename, validation_ratio=0.):
    # define model to be trained
    columns = [tf.contrib.layers.real_valued_column(str(col),
                                                    dtype=tf.int8)
               for col in FEATURE_COLS]
    classifier = tf.contrib.learn.DNNClassifier(
        feature_columns=columns,
        hidden_units=[100, 100],
        n_classes=N_LABELS,
        dropout=0.3)

    # load and split data
    print( 'Loading training data.')
    data = load_batch(filename)
    overall_size = data.shape[0]
    learn_size = int(overall_size * (1 - validation_ratio))
    learn, validation = np.array_split(data, [learn_size])
    print( 'Finished loading data. Samples count = '.format(overall_size))

    # learning
    print( 'Training using batch of size '.format(learn_size))
    classifier.fit(input_fn=lambda: pipeline(learn),
                   steps=learn_size)

    if validation_ratio > 0:
        validate_model(classifier, learn, validation)

    return classifier

运行这个函数后,我得到了一个我想保存的DNNClassifier

【问题讨论】:

你得到答案了吗?可以分享一下解决方法吗? 【参考方案1】:

我相信这里已经回答了这个问题:Tensorflow: how to save/restore a model?

saver = tf.train.Saver()
saver.save(sess, 'my_test_model',global_step=1000)

(从该问题的答案复制的代码)

【讨论】:

以上是关于如何使用 tensorflow 保存 DNN 模型的主要内容,如果未能解决你的问题,请参考以下文章

HCIA-AI_深度学习_利用TensorFlow进行手写数字识别

HCIA-AI_深度学习_利用TensorFlow进行手写数字识别

HCIA-AI_深度学习_利用TensorFlow进行手写数字识别

OpenCV DNN模块——从TensorFlow模型导出到OpenCV部署详解

将 TensorFlow Frozen Inference Graph 加载到 OpenCV DNN 时出错

如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它