tensorflow ckpt模型转saved_model格式并进行模型预测
Posted 修炼之路
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow ckpt模型转saved_model格式并进行模型预测相关的知识,希望对你有一定的参考价值。
导读
tensorflow的checkpoint
模型文件,只包含了模型的参数并不包含模型结构,为了方便使用tensorflow的serving
进行部署,我们需要将checkpoint模型转换为saved_model
格式
转换代码如下
def ckpt_to_pb(ckpt_path,output_pd_path):
"""
ckpt_path:checkpoint模型文件的目录
output_pd_path:savedmodel模型文件保存的目录
"""
#加载模型的参数文件
experiment_folder = "/tmp/"
config = json.load(open(experiment_folder + 'config.json'))
#根据的模型参数文件获取模型的结构(输入和输出)
[x, y_, is_train, y, normalized_y, cost] = train.tf_define_model_and_cost(config)
graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
#定义模型的输入输出节点
SignatureDef = sm.signature_def_utils.build_signature_def(
inputs=
"x_input": sm.utils.build_tensor_info(x),
"is_train": sm.utils.build_tensor_info(is_train)
,
outputs=
"y_sigmoid": sm.utils.build_tensor_info(normalized_y)
,
method_name=sm.signature_constants.PREDICT_METHOD_NAME,
)
#加载checkpoint模型参数
loader = tf.compat.v1.train.import_meta_graph(ckpt_path + ".meta")
loader.restore(sess,ckpt_path)
#将checkpoint模型转换为savedmodel
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_pd_path)
builder.add_meta_graph_and_variables(sess,tags = [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map=sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: SignatureDef,
strip_default_attrs=True)
builder.save()
加载savedmodel模型进行预测
import tensorflow as tf
export_dir = "/save_model"
#加载savedmodel模型
imported = tf.saved_model.load(export_dir)
model = imported.signatures["serving_default"]
#模型预测
pred = model(x_input=tf.convert_to_tensor(input_array), is_train=tf.constant(False))
#获取模型的预测结果
pred = pred["y_sigmoid"].numpy()
以上是关于tensorflow ckpt模型转saved_model格式并进行模型预测的主要内容,如果未能解决你的问题,请参考以下文章
tensorflow ckpt模型转saved_model格式并进行模型预测
tensorflow :ckpt模型转换为pytorch : hdf5模型
DL之GRU(Tensorflow框架):基于茅台股票数据集利用GRU算法实现回归预测(保存模型.ckpt.index.ckpt.data文件)
何时在 Tensorflow 模型保存中使用 .ckpt、.hdf5 和 .pb 文件扩展名?