如何保存/加载 tensorflow contrib.learn 回归器?

Posted

技术标签:

【中文标题】如何保存/加载 tensorflow contrib.learn 回归器?【英文标题】:How do I save/load a tensorflow contrib.learn regressor? 【发布时间】:2017-03-10 21:26:14 【问题描述】:

我有一个 tensorflow contrib.learn.DNNRegressor 作为以下代码 sn-p 的一部分进行了训练:

regressor = tf.contrib.learn.DNNRegressor(feature_columns=fc, 
                                          hidden_units=hu_array, 
                                          optimizer=tf.train.AdamOptimizer(
                                                       learning_rate=0.001,
                                                    ),
                                          enable_centered_bias=False,
                                          activation_fn=tf.tanh,
                                          model_dir="./models/my_model/",
                                          )

regressor.fit(x=training_features, y=training_labels, steps=10000)

经过训练的网络性能非常好,我想在另一台机器上将它用作其他代码的一部分。我尝试复制 models/my_model 目录,并构建一个新的 DNNRegressor 指向 model_dir,但它需要我提供 feature_columns 和 hidden_​​units 定义。这些信息不应该通过存储在 model_dir 中的快照获得吗?有没有更好的方法来保存/恢复性能良好的训练模型,用作预测器,而无需单独保存 feature_columns 和 hidden_​​units?

【问题讨论】:

【参考方案1】:

我想出了一些可行的方法——并不理想,但它可以完成工作。如果有人有更好的主意,我会全力以赴。

我将 DNNRegressor 的 kwargs 转换为字典,并使用了 ** 运算符。然后我能够腌制 kwargs 字典,并从中重建 DNNRegressor。例如:

reg_args = 'feature_columns': fc, 'hidden_units': hu_array, ...
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
pickle.dump(reg_args, open('reg_args.pkl', 'wb'))

稍后,我通过以下方式重构:

reg_args = pickle.load(open('reg_args.pkl', 'rb'))
# On another machine and so my model dir path changed:
reg_args['model_dir'] = NEW_MODEL_DIR
regressor = tf.contrib.learn.DNNRegressor(**reg_args)

效果很好。我确信一定有更好的方法,但现在如果有人试图找出 tf.contrib.learn 的解决方法,这是一个解决方案。

【讨论】:

【参考方案2】:

训练时

您调用DNNRegressor(..., model_dir),然后调用fit()evaluate() 方法。

测试时

您调用DNNRegressor(..., model_dir),然后可以调用predict() 方法。 您的模型将在 model_dir 中找到经过训练的模型,并将加载经过训练的模型参数。

参考

Issue #3340 of TF

【讨论】:

你能举个具体的例子吗?我正是这样做的,它一直给我错误。请帮忙举个例子,否则根本没用。

以上是关于如何保存/加载 tensorflow contrib.learn 回归器?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 TensorFlow 2 中保存/加载模型的一部分?

TensorFlow 从文件中保存/加载图形

转 tensorflow模型保存 与 加载

如何保存经过训练的模型(估计器)并将其加载回来以使用 Tensorflow 中的数据对其进行测试?

如何使用 c++ 在 tensorflow 中保存模型

TensorFlow 保存和加载不一致