Demo 1基于object_detection API的行人检测 3:模型训练与测试
Posted mxiaoy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Demo 1基于object_detection API的行人检测 3:模型训练与测试相关的知识,希望对你有一定的参考价值。
训练准备
模型选择
选择ssd_mobilenet_v2_coco模型,下载地址(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md),解压到./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29.
修改object_detection配置文件
进入目录./Pedestrian_Detection/models/research/object_detection/samples/configs 找到对应的模型配置文件ssd_mobilenet_v2_coco.config修改配置文件。
根据提示信息:
1、第9行,检测类别把90改为1,因为我们只检测行人,只有一个类别。
2、修改除提示外所有的
2.1、第一个(156行)是我们所需模型的路径,即上一步下载好的:./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt
2.2、第二个(175行)是train.record文件的路径,上一篇中我们准备好的record文件:./Pedestrian_Detection/project/pedestrian_train/data/pascal_train.record
2.3、第三个(177行)是上一篇准备好的label_map.pbtxt的路径:./Pedestrian_Detection/project/pedestrian_train/data/label_map.pbtxt
2.4、第四个(189行)是eval.record文件的路径,上一篇中我们准备好的record文件:./Pedestrian_Detection/project/pedestrian_train/data/pascal_eval.record
2.5、第五个(191行)同2.3
这样config文件就修改完成了。然后把它放到:./Pedestrian_Detection/project/pedestrian_train/models目录下。最后在该目录下创建两个文件夹:train 和 eval,用于储存训练和验证的记录。
开始训练
打开命令行窗口
在research目录下输入:
(dl) D:\\Study\\dl\\Pedestrian_Detection\\models\\research>python object_detection/legacy/train.py --train_dir=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\train --pipeline_config_path=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\ssd_mobilenet_v2_coco.config --logtostderr
即可开始训练。
这里我们选择2000次之后,按ctrl+c结束训练。训练的详细信息可通过tensorboard来进行查看(这里不再赘述)。
查看我们的训练记录:
导出模型
这里我们选择第2391次的训练数据来生成模型。
把下图4个文件放到:./Pedestrian_Detection/pedestrian_data/model 目录下
在命令行窗口下输入命令:
(dl) D:\\Study\\dl\\Pedestrian_Detection\\models\\research>python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=D:\\Study\\dl\\Pedestrian_Detection\\project\\pedestrian_train\\models\\ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix=D:\\Study\\dl\\Pedestrian_Detection\\pedestrian_data\\model\\model.ckpt-2391 --output_directory=D:\\Study\\dl\\Pedestrian_Detection\\pedestrian_data\\test
查看发现对应的目录下已经生成了一系列的模型文件:
测试模型
测试代码:
1 import os 2 import sys 3 4 import cv2 5 import numpy as np 6 import tensorflow as tf 7 8 sys.path.append("..") 9 from object_detection.utils import label_map_util 10 from object_detection.utils import visualization_utils as vis_util 11 12 ################################################## 13 14 ################################################## 15 16 # Path to frozen detection graph 17 PATH_TO_CKPT = ‘D:/Study/dl/Pedestrian_Detection/pedestrian_data/test/frozen_inference_graph.pb‘ 18 19 # List of the strings that is used to add correct label for each box. 20 PATH_TO_LABELS = os.path.join(‘D:/Study/dl/Pedestrian_Detection/project/pedestrian_train/data‘, ‘label_map.pbtxt‘) 21 22 NUM_CLASSES = 1 23 detection_graph = tf.Graph() 24 with detection_graph.as_default(): 25 od_graph_def = tf.GraphDef() 26 with tf.gfile.GFile(PATH_TO_CKPT, ‘rb‘) as fid: 27 serialized_graph = fid.read() 28 od_graph_def.ParseFromString(serialized_graph) 29 tf.import_graph_def(od_graph_def, name=‘‘) 30 31 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 32 categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 33 category_index = label_map_util.create_category_index(categories) 34 35 36 def load_image_into_numpy_array(image): 37 (im_width, im_height) = image.size 38 return np.array(image.getdata()).reshape( 39 (im_height, im_width, 3)).astype(np.uint8) 40 41 42 with detection_graph.as_default(): 43 with tf.Session(graph=detection_graph) as sess: 44 image_np = cv2.imread("D:/Study/dl/Pedestrian_Detection/project/test_images/3600.jpg") 45 # image_np = cv2.imread("D:/images/pedestrain.png") 46 cv2.imshow("input", image_np) 47 print(image_np.shape) 48 # image_np == [1, None, None, 3] 49 image_np_expanded = np.expand_dims(image_np, axis=0) 50 image_tensor = detection_graph.get_tensor_by_name(‘image_tensor:0‘) 51 boxes = detection_graph.get_tensor_by_name(‘detection_boxes:0‘) 52 scores = detection_graph.get_tensor_by_name(‘detection_scores:0‘) 53 classes = detection_graph.get_tensor_by_name(‘detection_classes:0‘) 54 num_detections = detection_graph.get_tensor_by_name(‘num_detections:0‘) 55 # Actual detection. 56 (boxes, scores, classes, num_detections) = sess.run( 57 [boxes, scores, classes, num_detections], 58 feed_dict=image_tensor: image_np_expanded) 59 # Visualization of the results of a detection. 60 vis_util.visualize_boxes_and_labels_on_image_array( 61 image_np, 62 np.squeeze(boxes), 63 np.squeeze(classes).astype(np.int32), 64 np.squeeze(scores), 65 category_index, 66 use_normalized_coordinates=True, 67 min_score_thresh=0.5, 68 line_thickness=1) 69 cv2.imshow(‘object detection‘, image_np) 70 cv2.imwrite("D:/run_result.png", image_np) 71 cv2.waitKey(0) 72 cv2.destroyAllWindows()
测试效果:
以上是关于Demo 1基于object_detection API的行人检测 3:模型训练与测试的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow object_detection 使用
使用tensorflow object_detection API训练自己的数据遇到的问题及解决方法
如何安装 TensorFlow 2 和 object_detection 模块?
具有 tensorflow/models/object_detection 的特征金字塔网络
“object_detection.protos.SsdFeatureExtractor”没有名为“use_depthwise”的字段