Tensorflow签名输出占位符
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow签名输出占位符相关的知识,希望对你有一定的参考价值。
我正在尝试导出Tensorflow模型,以便我可以在Tensorflow服务中使用它。这是我使用的脚本:
import os
import tensorflow as tf
trained_checkpoint_prefix = '/home/ubuntu/checkpoint'
export_dir = os.path.join('m', '0')
loaded_graph = tf.Graph()
config=tf.ConfigProto(allow_soft_placement=True)
with tf.Session(graph=loaded_graph, config=config) as sess:
# Restore from checkpoint
loader = tf.train.import_meta_graph(trained_checkpoint_prefix + 'file.meta')
loader.restore(sess, tf.train.latest_checkpoint(trained_checkpoint_prefix))
# Create SavedModelBuilder class
# defines where the model will be exported
export_path_base = "/home/ubuntu/m"
export_path = os.path.join(
tf.compat.as_bytes(export_path_base),
tf.compat.as_bytes(str(0)))
print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
batch_shape = (20, 256, 256, 3)
input_tensor = tf.placeholder(tf.float32, shape=batch_shape, name="X_content")
predictions_tf = tf.placeholder(tf.float32, shape=batch_shape, name='Y_output')
tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
tensor_info_output = tf.saved_model.utils.build_tensor_info(predictions_tf)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'image': tensor_info_input},
outputs={'output': tensor_info_output},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'style_image':
prediction_signature,
})
builder.save(as_text=True)
主要问题是输出签名(predictions_tf)。在这种情况下,将它设置为占位符,我得到一个错误,说明在从gRPC调用模型时必须设置它的值。我应该用什么呢?
我试过了
predictions_tf = tf.Variable(0, dtype=tf.float32, name="Y_output")
和
predictions_tf = tf.TensorInfo(dtype=tf.float32)
predictions_tf.name = "Y_output"
predictions_tf.dtype = tf.float32
答案
我可能会误解你想要做什么,但在这里你基本上创建了一个新的placeholder
输入和一个新的placeholder
输出。
我认为你应该做的是,一旦加载模型,你必须在变量input tensor
和prediction_tf
using中得到模型的输入和输出张量,例如
input_tensor=loaded_graph.get_tensor_by_name('the_name_in_the_loaded_graph:0')
prediction_tf=loaded_graph.get_tensor_by_name('the_pred_name_in_the_loaded_graph:0')
以上是关于Tensorflow签名输出占位符的主要内容,如果未能解决你的问题,请参考以下文章