提高 Tensorflow 图像分类的速度

Posted

技术标签:

【中文标题】提高 Tensorflow 图像分类的速度【英文标题】:Increase the speed for image classification of Tensorflow 【发布时间】:2017-10-15 02:02:00 【问题描述】:

使用来自 Link 的说明,我重新训练了新类别的 tensorflow inception。

但我注意到,随后如果我想对一组图像进行分类,它会逐个遍历图像然后对其进行分类。如果数据集很大,完成分类需要很长时间。例如,1000 张图片需要 45 分钟。

对于图像分类,我使用了online 可用的 LabelImage.py,如下所示:

import tensorflow as tf
import sys

image_path = sys.argv[1] #Pass the test file as argument

# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()

# Loads label file (the retained labels from retraining) and strips off carriage return
label_lines = [line.rstrip() for line
                   in tf.gfile.GFile("/tf_files/tf_files/retrained_labels.txt")]

# Unpersists graph from file
with tf.gfile.FastGFile("/tf_files/tf_files/retrained_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # Feed the image_data as input to the graph and get first prediction i.e. the most likely result
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

    predictions = sess.run(softmax_tensor, \
             'DecodeJpeg/contents:0': image_data)

    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        print('%s (score = %.5f)' % (human_string, score))

如您所见,它会逐张处理图像。

是否可以加快进程?当我重新训练库时,它不是为多个 GPU 编译的。有没有其他方法可以加快分类过程?

【问题讨论】:

可能你的网络是在[None, ...] 输入上训练的,这意味着它可能接受任意数量的输入。您应该能够通过创建一堆图像并将它们提供给图层来做到这一点。您必须将predictions[0].argsort()[...] 更改为predictions.argsort(axis=1)[: , ...]。但简而言之,这类问题不适合 SO,而且有点离题。你应该能够用上面的 cmets 想出一个解决方案。 【参考方案1】:

方法是加载库并随后处理图像。因此将节省每个图像的加载部分。

找到答案here。

【讨论】:

【参考方案2】:

为了完善 aandroidtest 给出的答案,如果您使用上述脚本,您将花费大部分时间为每个图像重新加载模型等。

相反,您应该只加载一次模型,然后在脚本中一张一张地检查图像。

【讨论】:

以上是关于提高 Tensorflow 图像分类的速度的主要内容,如果未能解决你的问题,请参考以下文章

基于Tensorflow + Opencv 实现CNN自定义图像分类

基于Tensorflow + Opencv 实现CNN自定义图像分类

使用 Tensorflow 识别错误分类的图像

干货快速上手图像识别:用TensorFlow API实现图像分类实例

使用 TensorFlow CNN 进行图像分类

在 TensorFlow 图像分类中获取标签