如何从检查点使用 tf.estimator.Estimator 进行预测?
Posted
技术标签:
【中文标题】如何从检查点使用 tf.estimator.Estimator 进行预测?【英文标题】:How to make predictions with tf.estimator.Estimator from checkpoint? 【发布时间】:2018-06-03 07:30:38 【问题描述】:我刚刚训练了一个 CNN 来使用 tensorflow 识别太阳黑子。我的模型与this 几乎相同。 问题是我在任何地方都找不到关于如何使用训练阶段生成的检查点进行预测的明确解释。
尝试使用标准恢复方法:
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,'./model/model.ckpt')
但是我不知道如何运行它。
尝试像这样使用tf.estimator.Estimator.predict()
:
# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir="./model")
# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = "probabilities": "softmax_tensor"
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=50)
# predict with the model and print results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
x="x": pred_data,
shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)
但它的作用是吐出<generator object Estimator.predict at 0x10dda6bf8>
。
虽然如果我使用相同的代码但使用 tf.estimator.Estimator.evaluate()
它就像一个魅力(重新加载模型,执行评估并将其发送到 TensorBoard)。
我知道有很多类似的问题,但我真的找不到适合我的方法。
【问题讨论】:
Ciao @RobiNoob,你在哪里使用logging_hook
?
【参考方案1】:
sunspot_classifier.predict(input_fn=pred_input_fn)
返回生成器。所以pred_results
是生成器对象。要从中获得价值,您需要通过 next(pred_results)
对其进行迭代
解决办法是
print(next(pred_results))
【讨论】:
以上是关于如何从检查点使用 tf.estimator.Estimator 进行预测?的主要内容,如果未能解决你的问题,请参考以下文章
如何从 C# 或 Javascript 检查文件是不是存在?