在 tensorflow 中使用两种不同的模型

Posted

技术标签:

【中文标题】在 tensorflow 中使用两种不同的模型【英文标题】:Using two different models in tensorflow 【发布时间】:2018-01-26 14:47:10 【问题描述】:

我正在尝试使用两种不同的 mobilenet 模型。以下是我如何初始化模型的代码。

def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='')

    print ('Took  seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

由于两者都是两个不同的模型,我该如何使用它进行预测?

更新

initialSetup()

age_session = tf.Session(graph=age_graph_def)
gender_session = tf.Session(graph=gender_graph_def)

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0')

    print ('Took  seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

错误

Traceback(最近一次调用最后一次):文件 “C:/Users/Desktop/untitled/testimg/testimg/combo.py”,第 48 行,在 age_session = tf.Session(graph=age_graph_def) 文件 "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", 第 1292 行,在 init 中 super(Session, self).init(target, graph, config=config) File "C:\Program 文件\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", 第 529 行,在 init 中 raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) TypeError: graph must be a tf.Graph, but got Exception ignored in: > Traceback(最近一次调用最后一次):文件 "C:\程序 文件\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", 第 587 行,在 del 如果 self._session 不是 None: AttributeError: 'Session' object has no attribute '_session'

【问题讨论】:

您是否成功使用过这种方式加载的单个模型?通常的方法是将不同的非空 name 参数传递给每个 tf.import_graph_def() 调用,然后使用这些名称作为您要提供和获取的每个模型中特定张量的前缀。 是的,它单独工作。如果我添加名称,它会说,不存在这样的张量 您能否添加您用于调用会话的代码以及打印的完整错误?如果将name 添加到导入的图表中,则需要在该图表中使用的任何张量名称前加上name 的值,后跟/ 请检查更新的问题... 【参考方案1】:

当您在同一个图中使用多个模型时,使用名称范围为各个张量提供可预测的名称。例如,您可以将initial_setup() 改写如下:

def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        tf.import_graph_def(age_graph_def, name='age_model')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        tf.import_graph_def(gender_graph_def, name='gender_model')

    print ('Took  seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

现在age_graph_def 中的所有节点的名称都将以"age_model/" 为前缀,gender_graph_def 中的所有节点的名称都将以"gender_model/" 为前缀。它们都是同一个默认图的一部分,因此您可以使用单个 tf.Session 而不使用 graph 参数来访问任一模型。

initialSetup()

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0')

    # Alternatively, to get a tensor from the gender model:
    # tensor = sess.graph.get_tensor_by_name('gender_model/...')

    print ('Took  seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

【讨论】:

谢谢它的工作..但是框架现在有点滞后......有什么方法可以提高它的速度吗? 如果我将两个模型中的两个类合并并重新训练,会影响准确性吗? @mrry 也适用于我,但速度是一个主要问题!有什么建议吗?谢谢!【参考方案2】:

tf.Session 需要 tf.Graph 实例而不是 tf.GraphDef,请按照以下步骤解决问题。

def initialSetup():
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(age_graph_def, name='')
            age_graph = graph

   ...
   return age_graph, gender_graph

age_graph, gender_graph = initial_setup() 
age_session = tf.Session(graph=age_graph)
...
# also delete the following line, as it creates another new context 
with tf.Session() as sess:

【讨论】:

现在它说,KeyError: "The name 'final_result:0' refers to a Tensor which does not exist. The operation, 'final_result', does not exist in the graph." 它什么也不打印:( 问题在于这一行with tf.Session() as sess: 你不应该在这里定义另一个会话。因为age_session 已经定义了。 我尝试使用 age_session 作为 sess: 它说 KeyError: "The name 'final_result:0' refers to a Tensor which does not exist. The operation, 'final_result', does not exist in the graph." tf.global_variables() 正在打印空列表

以上是关于在 tensorflow 中使用两种不同的模型的主要内容,如果未能解决你的问题,请参考以下文章

在 Tensorflow 中,将 Google 的 BigTransfer 模型转换为 Tensorflow Lite 时出现错误

Keras,Tensorflow:将两个不同的模型输出合并为一个

在不同的会话中加载 TensorFlow 模型

tensorflow的断点续训

如何使用 TensorFlow 连接两个具有不同形状的张量?

干货使用TensorFlow官方Java API调用TensorFlow模型(附代码)