小白学习tensorflow教程四使用 tfhub中的模型EfficientDet-Lite2 进行对象检测

Posted 刘润森!

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了小白学习tensorflow教程四使用 tfhub中的模型EfficientDet-Lite2 进行对象检测相关的知识,希望对你有一定的参考价值。

@Author:Runsen

tfhub是tensorflow官方提供训练好的模型的一个仓库。今天,我使用 tfhub中的模型EfficientDet-Lite2 进行对象检测

选择的模型是EfficientDet-Lite2 对象检测模型。它在具有 91 个不同标签的 COCO17 数据集上进行了训练,并针对 TFLite 应用程序进行了优化。

EfficientDet-Lite 是一系列对移动/物联网友好的对象检测模型。

目前暂时无法直接正常访问 https://tfhub.dev,可以通过镜像 https://hub.tensorflow.google.cn

加载TensorFlow 模型

图片输入:大小可变的三通道图像。输入张量是一个 tf.uint8具有形状张量[None, height, width, 3]。


加载图像并将其处理为 TensorFlow 模型的tensor格式。

import tensorflow_hub as hub
import cv2
import numpy
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

width = 1028
height = 1028

img = cv2.imread('image.jpg')
inp = cv2.resize(img, (width , height ))
rgb = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
rgb_tensor = tf.convert_to_tensor(rgb, dtype=tf.uint8)
rgb_tensor = tf.expand_dims(rgb_tensor , 0)

加载EfficientDet-Lite2,官方给出加载代码

由于下载模型网速不行,将开发环境转到colab中。

detector = hub.load("https://hub.tensorflow.google.cn/tensorflow/efficientdet/lite2/detection/1")
boxes, scores, classes, num_detections = detector(boxes, scores, classes, num_detections = detector(rgb_tensor))

输出字典包含:

  • detection_boxes:一个tf.float32形状的张量[N, 4]含有以下列顺序边界框坐标:[ymin, xmin, ymax, xmax]。
  • detection_scores:包含检测分数tf.float32的形状张量[N]。
  • detection_classes:包含标签文件中检测类索引tf.int的形状张量[N]。
  • num_detections:一个tf.int只有一个值的张量,检测次数[N]。

下面是检测的scores和class


物体编号:47,72,77检测的置信度比较高。

label = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
       'train', 'truck', 'boat', 'traffic light', 'fire hydrant', '-',
       'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
       'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
       '-', 'backpack', 'umbrella', '-', '-', 'handbag', 'tie',
       'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
       'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
       'tennis racket', 'bottle', '-', 'wine glass', 'cup', 'fork',
       'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
       'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
       'couch', 'potted plant', 'bed', '-', 'dining table', '-', '-',
       'toilet', '-', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
       'cell phone', 'microwave', 'oven', 'toaster', 'sink',
       'refrigerator', '-', 'book', 'clock', 'vase', 'scissors',
       'teddy bear', 'hair drier', 'toothbrush', '-']

下面将tensor转换为numpy()

pred_labels = classes.numpy().astype('int')[0] 
pred_labels = [labels[i] for i in pred_labels]
pred_boxes = boxes.numpy()[0].astype('int')
pred_scores = scores.numpy()[0]

最后设置一个阈值,将检测框加到图片上。

for score, (ymin,xmin,ymax,xmax), label in zip(pred_scores, pred_boxes, pred_labels):
        if score < 0.5:
            continue
        score_txt = f'{100 * round(score)}%'
        img_boxes = cv2.rectangle(rgb,(xmin, ymax),(xmax, ymin),(0,255,0),2)      
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(img_boxes, label,(xmin, ymax-10), font, 0.5, (255,0,0), 2, cv2.LINE_AA)
        cv2.putText(img_boxes,score_txt,(xmax, ymax-10), font, 0.5, (255,0,0), 2, cv2.LINE_AA)
plt.figure(figsize=(10,10))
plt.imshow(img_boxes)
plt.savefig('image_pred.jpg',transparent=True, )

以上是关于小白学习tensorflow教程四使用 tfhub中的模型EfficientDet-Lite2 进行对象检测的主要内容,如果未能解决你的问题,请参考以下文章

将 TFHUB 中的 TensorFlow 模型加载到 BigQuery 中

小白学习tensorflow教程一tensorflow基本操作快速构建线性回归和分类模型

小白学习tensorflow教程一tensorflow基本操作快速构建线性回归和分类模型

小白学习tensorflow教程二TensorBoard可视化模型训练

小白学习tensorflow教程三TF2新特性@tf.function和AutoGraph

小白入门深度学习 | 第六篇:TensorFlow2 回调极速入门