Tensorflow加载多个模型方法实践——Graph与Session

Posted 肖永威

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow加载多个模型方法实践——Graph与Session相关的知识,希望对你有一定的参考价值。


Tensorflow(1.x版本)是一种符号式编程框架,首先要构造一个图(graph),然后在这个图上做运算,也就是用计算图来构建网络,用会话(Session)来具体执行网络。

计算框架,是通过定义placeholder、Variable和OP等构成一张完成计算图Graph;接下来通过新建Session实例启动模型运行,Session实例会分布式执行Graph,输入数据,根据优化算法更新Variable,然后返回执行结果即Tensor实例。

  • 计算图graph定义了计算过程与公式,是一些加减乘除等数学运算的组合。它本身不会进行任何计算,也不保存任何中间计算结果。

  • session用来运行一个graph,或者运行graph的一部分。它类似于一个执行者,给graph灌入输入数据,得到输出,并保存中间的计算结果。同时它也给graph分配计算资源(如内存、显卡等)。

  • 一个graph可以供多个session使用,而一个session不一定需要使用graph的全部,可以只使用其中的一部分。

通常,使用上下文管理器的代码结果如下,其中使用默认graph。

        with tf.Session() as sess:
            saver = tf.train.import_meta_graph(DB_info.BPNetModel_graph)
            saver.restore(sess,tf.train.latest_checkpoint(DB_info.BPNetModel))
            graph = tf.get_default_graph()
            x = graph.get_tensor_by_name("x:0")
            # 输出预测结果
            y_conv = graph.get_tensor_by_name('y_conv:0')
            keep_prob = graph.get_tensor_by_name("keep_prob:0")    
            ret = sess.run(y_conv, feed_dict=x:dtest,keep_prob:1.0)
            y = sess.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率

Tensorflow加载多个模型方法是在Tensorflow中创建多个Session,每个Session运行一个graph,实践案例代码如下。

    def ChurnModelWorking(self):
        graph = tf.Graph()                                                      # 定义图1
        with tf.Session(graph = graph) as sess:                                 # Session加载所定义的图
            saver = tf.train.import_meta_graph(DB_info.BPNetModel_graph)        # 加载模型图
            saver.restore(sess,tf.train.latest_checkpoint(DB_info.BPNetModel))  # 恢复模型参数
            
            x = graph.get_tensor_by_name("x:0")                                 # 从图中获取输入定义
            # 输出预测结果
            y_conv = graph.get_tensor_by_name('y_conv:0')                       # 从图中获取输出定义
            keep_prob = graph.get_tensor_by_name("keep_prob:0")    
            ret = sess.run(y_conv, feed_dict=x:dtest,keep_prob:1.0)
            y = sess.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率

    def ChurnOtherModelWorking(self):
        graph_other = tf.Graph()                                                # 定义另一个图2
        with tf.Session(graph=graph_other) as sess_other:                       # Session加载所定义的图2
            saver_other = tf.train.import_meta_graph(DB_info.BPNetModelOther_graph)
            saver_other.restore(sess_other,tf.train.latest_checkpoint(DB_info.BPNetModelOther))            
            x = graph_other.get_tensor_by_name("x:0")
            # 输出预测结果
            y_conv = graph_other.get_tensor_by_name('y_conv:0')
            keep_prob = graph_other.get_tensor_by_name("keep_prob:0")    
            ret = sess_other.run(y_conv, feed_dict=x:dtest,keep_prob:1.0)
            y = sess_other.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率  

实践总结
通过在Tensorflow中创建多个Session,每个Session运行一个graph,来实现加载多个Tensorflow模型进行组合应用。

关于Tensorflow的图
graph视角的关系图

TensorFlow是一个通过计算图的形式来表述计算的编程系统。其中的Tnesor,代表它的数据结构,而Flow代表它的计算模型。TensorFlow中的每一个计算都是计算图上的一个节点,而节点之间的线描述了计算之间的依赖关系。

在TensorFlow程序中,系统会自动维护一个默认的计算图,通过tf.get_default_gragh函数可以获取当前默认的计算图。除了默认的计算图,TensorFlow也支持通过tf.Graph函数来生成新的计算图。不同的计算图上的张量和运算不会共享。

计算图举例如下:

参考:

[1]. Echo. ​Tensorflow Session使用和浅析. 知乎. 2020.03
[2]. Arkenstone. Tensorflow同时加载使用多个模型. 博客园. 2017.06
[3]. 马尔代夫Maldives. tensorflow中的Graph(图)和Session(会话)的关系. 简书. 2019.07
[4]. 老夫叨叨叨. tensorflow: graph. 简书. 2019.01
[5]. HOU_JUN. TensorFlow计算模型—计算图. 博客园. 2018.04

以上是关于Tensorflow加载多个模型方法实践——Graph与Session的主要内容,如果未能解决你的问题,请参考以下文章

如何将多个 GPU 用于协同工作的多个模型?

在 TensorFlow Functional API 中保存和加载具有相同图形的多个模型

81TensorFlow 2 模型部署方法实践--TensorFlow Serving 部署模型

基于Tensorflow2.x低阶API搭建神经网络模型并训练及解决梯度爆炸与消失方法实践

基于Tensorflow2.x低阶API搭建神经网络模型并训练及解决梯度爆炸与消失方法实践

基于Tensorflow2.x低阶API搭建神经网络模型并训练及解决梯度爆炸与消失方法实践