tensorflow ckpt模型转saved_model格式并进行模型预测

Posted 修炼之路

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow ckpt模型转saved_model格式并进行模型预测相关的知识,希望对你有一定的参考价值。

导读

tensorflow的checkpoint模型文件,只包含了模型的参数并不包含模型结构,为了方便使用tensorflow的serving进行部署,我们需要将checkpoint模型转换为saved_model格式

转换代码如下

def ckpt_to_pb(ckpt_path,output_pd_path):
"""
ckpt_path:checkpoint模型文件的目录
output_pd_path:savedmodel模型文件保存的目录
"""
	#加载模型的参数文件
    experiment_folder = "/tmp/"
    config = json.load(open(experiment_folder + 'config.json'))
	#根据的模型参数文件获取模型的结构(输入和输出)
    [x, y_, is_train, y, normalized_y, cost] = train.tf_define_model_and_cost(config)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
		#定义模型的输入输出节点
        SignatureDef = sm.signature_def_utils.build_signature_def(
            inputs=
                "x_input": sm.utils.build_tensor_info(x),
                "is_train": sm.utils.build_tensor_info(is_train)
            ,
            outputs=
                "y_sigmoid": sm.utils.build_tensor_info(normalized_y)
            ,
            method_name=sm.signature_constants.PREDICT_METHOD_NAME,
        )
		#加载checkpoint模型参数	
        loader = tf.compat.v1.train.import_meta_graph(ckpt_path + ".meta")
        loader.restore(sess,ckpt_path)
		
		#将checkpoint模型转换为savedmodel
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_pd_path)
        builder.add_meta_graph_and_variables(sess,tags = [tf.compat.v1.saved_model.tag_constants.SERVING],
                                             signature_def_map=sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: SignatureDef,
                                             strip_default_attrs=True)

        builder.save()

加载savedmodel模型进行预测

import tensorflow as tf

export_dir = "/save_model"
#加载savedmodel模型
imported = tf.saved_model.load(export_dir)
model = imported.signatures["serving_default"]
#模型预测
pred = model(x_input=tf.convert_to_tensor(input_array), is_train=tf.constant(False))
#获取模型的预测结果
pred = pred["y_sigmoid"].numpy()

以上是关于tensorflow ckpt模型转saved_model格式并进行模型预测的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow ckpt模型转saved_model格式并进行模型预测

tensorflow :ckpt模型转换为pytorch : hdf5模型

DL之GRU(Tensorflow框架):基于茅台股票数据集利用GRU算法实现回归预测(保存模型.ckpt.index.ckpt.data文件)

何时在 Tensorflow 模型保存中使用 .ckpt、.hdf5 和 .pb 文件扩展名?

TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

tensorflow和pytorch模型之间转换