Demo 1基于object_detection API的行人检测 3:模型训练与测试

Posted mxiaoy

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Demo 1基于object_detection API的行人检测 3:模型训练与测试相关的知识,希望对你有一定的参考价值。

训练准备

模型选择

选择ssd_mobilenet_v2_coco模型,下载地址(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md),解压到./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29.

修改object_detection配置文件

进入目录./Pedestrian_Detection/models/research/object_detection/samples/configs 找到对应的模型配置文件ssd_mobilenet_v2_coco.config修改配置文件。

根据提示信息:

技术图片

1、第9行,检测类别把90改为1,因为我们只检测行人,只有一个类别。

2、修改除提示外所有的技术图片

  2.1、第一个(156行)是我们所需模型的路径,即上一步下载好的:./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt

  2.2、第二个(175行)是train.record文件的路径,上一篇中我们准备好的record文件:./Pedestrian_Detection/project/pedestrian_train/data/pascal_train.record

  2.3、第三个(177行)是上一篇准备好的label_map.pbtxt的路径:./Pedestrian_Detection/project/pedestrian_train/data/label_map.pbtxt

  2.4、第四个(189行)是eval.record文件的路径,上一篇中我们准备好的record文件:./Pedestrian_Detection/project/pedestrian_train/data/pascal_eval.record

  2.5、第五个(191行)同2.3

这样config文件就修改完成了。然后把它放到:./Pedestrian_Detection/project/pedestrian_train/models目录下。最后在该目录下创建两个文件夹:train 和 eval,用于储存训练和验证的记录。

开始训练

打开命令行窗口

在research目录下输入:

(dl) D:\\Study\\dl\\Pedestrian_Detection\\models\\research>python object_detection/legacy/train.py --train_dir=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\train --pipeline_config_path=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\ssd_mobilenet_v2_coco.config --logtostderr

即可开始训练。

这里我们选择2000次之后,按ctrl+c结束训练。训练的详细信息可通过tensorboard来进行查看(这里不再赘述)。

查看我们的训练记录:

技术图片

导出模型

这里我们选择第2391次的训练数据来生成模型。

把下图4个文件放到:./Pedestrian_Detection/pedestrian_data/model  目录下

技术图片

 

 

 

 

在命令行窗口下输入命令:

(dl) D:\\Study\\dl\\Pedestrian_Detection\\models\\research>python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix=D:\\Study\\dl\\Pedestrian_Detection\\pedestrian_data\\model\\model.ckpt-2391 --output_directory=D:\\Study\\dl\\Pedestrian_Detection\\pedestrian_data\\test

查看发现对应的目录下已经生成了一系列的模型文件:

技术图片

 

测试模型

测试代码:

 1 import os
 2 import sys
 3 
 4 import cv2
 5 import numpy as np
 6 import tensorflow as tf
 7 
 8 sys.path.append("..")
 9 from object_detection.utils import label_map_util
10 from object_detection.utils import visualization_utils as vis_util
11 
12 ##################################################
13 
14 ##################################################
15 
16 # Path to frozen detection graph
17 PATH_TO_CKPT = D:/Study/dl/Pedestrian_Detection/pedestrian_data/test/frozen_inference_graph.pb
18 
19 # List of the strings that is used to add correct label for each box.
20 PATH_TO_LABELS = os.path.join(D:/Study/dl/Pedestrian_Detection/project/pedestrian_train/data, label_map.pbtxt)
21 
22 NUM_CLASSES = 1
23 detection_graph = tf.Graph()
24 with detection_graph.as_default():
25     od_graph_def = tf.GraphDef()
26     with tf.gfile.GFile(PATH_TO_CKPT, rb) as fid:
27         serialized_graph = fid.read()
28         od_graph_def.ParseFromString(serialized_graph)
29         tf.import_graph_def(od_graph_def, name=‘‘)
30 
31 label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
32 categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
33 category_index = label_map_util.create_category_index(categories)
34 
35 
36 def load_image_into_numpy_array(image):
37     (im_width, im_height) = image.size
38     return np.array(image.getdata()).reshape(
39       (im_height, im_width, 3)).astype(np.uint8)
40 
41 
42 with detection_graph.as_default():
43     with tf.Session(graph=detection_graph) as sess:
44         image_np = cv2.imread("D:/Study/dl/Pedestrian_Detection/project/test_images/3600.jpg")
45         # image_np = cv2.imread("D:/images/pedestrain.png")
46         cv2.imshow("input", image_np)
47         print(image_np.shape)
48         # image_np == [1, None, None, 3]
49         image_np_expanded = np.expand_dims(image_np, axis=0)
50         image_tensor = detection_graph.get_tensor_by_name(image_tensor:0)
51         boxes = detection_graph.get_tensor_by_name(detection_boxes:0)
52         scores = detection_graph.get_tensor_by_name(detection_scores:0)
53         classes = detection_graph.get_tensor_by_name(detection_classes:0)
54         num_detections = detection_graph.get_tensor_by_name(num_detections:0)
55         # Actual detection.
56         (boxes, scores, classes, num_detections) = sess.run(
57             [boxes, scores, classes, num_detections],
58             feed_dict=image_tensor: image_np_expanded)
59         # Visualization of the results of a detection.
60         vis_util.visualize_boxes_and_labels_on_image_array(
61               image_np,
62               np.squeeze(boxes),
63               np.squeeze(classes).astype(np.int32),
64               np.squeeze(scores),
65               category_index,
66               use_normalized_coordinates=True,
67               min_score_thresh=0.5,
68               line_thickness=1)
69         cv2.imshow(object detection, image_np)
70         cv2.imwrite("D:/run_result.png", image_np)
71         cv2.waitKey(0)
72         cv2.destroyAllWindows()

测试效果:

技术图片

 

以上是关于Demo 1基于object_detection API的行人检测 3:模型训练与测试的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow object_detection 使用

使用tensorflow object_detection API训练自己的数据遇到的问题及解决方法

如何安装 TensorFlow 2 和 object_detection 模块?

具有 tensorflow/models/object_detection 的特征金字塔网络

“object_detection.protos.SsdFeatureExtractor”没有名为“use_depthwise”的字段

protoc object_detection/protos/*.proto: 没有这样的文件或目录