如何保存/加载 tensorflow contrib.learn 回归器?
Posted
技术标签:
【中文标题】如何保存/加载 tensorflow contrib.learn 回归器?【英文标题】:How do I save/load a tensorflow contrib.learn regressor? 【发布时间】:2017-03-10 21:26:14 【问题描述】:我有一个 tensorflow contrib.learn.DNNRegressor 作为以下代码 sn-p 的一部分进行了训练:
regressor = tf.contrib.learn.DNNRegressor(feature_columns=fc,
hidden_units=hu_array,
optimizer=tf.train.AdamOptimizer(
learning_rate=0.001,
),
enable_centered_bias=False,
activation_fn=tf.tanh,
model_dir="./models/my_model/",
)
regressor.fit(x=training_features, y=training_labels, steps=10000)
经过训练的网络性能非常好,我想在另一台机器上将它用作其他代码的一部分。我尝试复制 models/my_model 目录,并构建一个新的 DNNRegressor 指向 model_dir,但它需要我提供 feature_columns 和 hidden_units 定义。这些信息不应该通过存储在 model_dir 中的快照获得吗?有没有更好的方法来保存/恢复性能良好的训练模型,用作预测器,而无需单独保存 feature_columns 和 hidden_units?
【问题讨论】:
【参考方案1】:我想出了一些可行的方法——并不理想,但它可以完成工作。如果有人有更好的主意,我会全力以赴。
我将 DNNRegressor 的 kwargs 转换为字典,并使用了 ** 运算符。然后我能够腌制 kwargs 字典,并从中重建 DNNRegressor。例如:
reg_args = 'feature_columns': fc, 'hidden_units': hu_array, ...
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
pickle.dump(reg_args, open('reg_args.pkl', 'wb'))
稍后,我通过以下方式重构:
reg_args = pickle.load(open('reg_args.pkl', 'rb'))
# On another machine and so my model dir path changed:
reg_args['model_dir'] = NEW_MODEL_DIR
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
效果很好。我确信一定有更好的方法,但现在如果有人试图找出 tf.contrib.learn 的解决方法,这是一个解决方案。
【讨论】:
【参考方案2】:训练时
您调用DNNRegressor(..., model_dir)
,然后调用fit()
和evaluate()
方法。
测试时
您调用DNNRegressor(..., model_dir)
,然后可以调用predict()
方法。 您的模型将在 model_dir
中找到经过训练的模型,并将加载经过训练的模型参数。
参考
Issue #3340 of TF
【讨论】:
你能举个具体的例子吗?我正是这样做的,它一直给我错误。请帮忙举个例子,否则根本没用。以上是关于如何保存/加载 tensorflow contrib.learn 回归器?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 TensorFlow 2 中保存/加载模型的一部分?