Building Your Own Object Recognition Model with the TensorFlow Object Detection API

Posted by 嗨_放飞梦想


1. Download the TensorFlow Object Detection API source code

  Repository: https://github.com/tensorflow/models. It can be downloaded with Git: open Git Bash and run git clone https://github.com/tensorflow/models.git.

2. Label the training images

  ① In the models\research\object_detection directory of the repository downloaded in step 1, create a my_test_images folder with test and train subfolders. Put the images you want to recognize into test and train for testing and training.

  ② Download the labelImg tool from https://tzutalin.github.io/labelImg/, open labelImg.exe, click Open Dir, open models\research\object_detection\my_test_images\test (and then train), and annotate every image. Save each annotation as an .xml file with the same name as the image (a small consistency check is sketched below).
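  Before converting to CSV, it can help to confirm that every image actually has a matching annotation. This is a small helper sketch, not part of the original post; the folder path is an assumption and should point at your own test or train folder.

import os

def find_unannotated(folder):
    """Return image files in `folder` that have no .xml annotation of the same name."""
    images = [f for f in os.listdir(folder)
              if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    return [f for f in images
            if not os.path.exists(os.path.join(folder, os.path.splitext(f)[0] + '.xml'))]

# Hypothetical path; adjust to your own my_test_images\test or \train folder.
print(find_unannotated(r'C:\Users\CFF\Desktop\models\research\object_detection\my_test_images\train'))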

  ③ Under models\research\object_detection\my_test_images, create a folder named xml_to_csv, and inside it create test_xml_to_csv.py and train_xml_to_csv.py.

    test_xml_to_csv.py is as follows:

# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 21:50:27 2019

@author: CFF
"""

import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET

os.chdir('C:\\Users\\CFF\\Desktop\\models\\research\\object_detection\\my_test_images\\test')
path = 'C:\\Users\\CFF\\Desktop\\models\\research\\object_detection\\my_test_images\\test'

def xml_to_csv(path):
    """Collect every LabelImg .xml annotation under `path` into one DataFrame."""
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            # filename, image width/height, class name, and bounding box corners
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df

def main():
    image_path = path
    xml_df = xml_to_csv(image_path)
    xml_df.to_csv('cat_test.csv', index=None)  # cat_test.csv can be replaced with your own file name
    print('Successfully converted xml to csv.')

main()
 

 Open test_xml_to_csv.py in Spyder and run it. A cat_test.csv file is generated under C:\Users\CFF\Desktop\models\research\object_detection\my_test_images\test, which can be opened with Excel.

 Likewise, train_xml_to_csv.py is as follows:

# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 21:48:33 2019

@author: CFF
"""
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET

os.chdir('C:\\Users\\CFF\\Desktop\\models\\research\\object_detection\\my_test_images\\train')
path = 'C:\\Users\\CFF\\Desktop\\models\\research\\object_detection\\my_test_images\\train'

def xml_to_csv(path):
    """Collect every LabelImg .xml annotation under `path` into one DataFrame."""
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            # filename, image width/height, class name, and bounding box corners
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df

def main():
    image_path = path
    xml_df = xml_to_csv(image_path)
    xml_df.to_csv('cat_train.csv', index=None)
    print('Successfully converted xml to csv.')

main()

   Open train_xml_to_csv.py in Spyder and run it. A cat_train.csv file is generated under C:\Users\CFF\Desktop\models\research\object_detection\my_test_images\train, which can be opened with Excel.

3. Convert cat_train.csv and cat_test.csv into the train.record and test.record datasets

  ① First place cat_train.csv and cat_test.csv in the C:\Users\CFF\Desktop\models\research\object_detection\data folder.

  ② Create an images folder under C:\Users\CFF\Desktop\models\research\object_detection and put the training and testing images into it (a small copy sketch is shown below).
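  A minimal sketch of this copy step, under the assumption that the annotated images live in my_test_images\train and my_test_images\test as described in section 2 and that it is run from the object_detection folder; it is a convenience, not part of the original post.

import os
import shutil

SRC_DIRS = [r'my_test_images\train', r'my_test_images\test']
DST_DIR = 'images'

os.makedirs(DST_DIR, exist_ok=True)
for src in SRC_DIRS:
    for name in os.listdir(src):
        # Copy only the image files; the .xml annotations stay where they are.
        if name.lower().endswith(('.jpg', '.jpeg', '.png')):
            shutil.copy(os.path.join(src, name), os.path.join(DST_DIR, name))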

  ③ Using Spyder, create a generate_tfrecord.py file under C:\Users\CFF\Desktop\models\research\object_detection. The code of generate_tfrecord.py is as follows:

# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 21:56:20 2019

@author: CFF
"""

"""
Usage:
  # From tensorflow/models/
  # Create train data:
  python generate_tfrecord.py --csv_input=data/cat_train.csv    --output_path=data/train.record
  # Create test data:
  python generate_tfrecord.py --csv_input=data/cat_test.csv     --output_path=data/test.record
"""

import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict

os.chdir('C:\\Users\\CFF\\Desktop\\models\\research\\object_detection')

flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS


# TO-DO replace this with label map
def class_text_to_int(row_label):  # label mapping; edit to match your own classes
    if row_label == 'cat':
        return 1
    else:
        return None


def split(df, group):
    # Group the CSV rows by filename so that each image becomes one tf.Example.
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        # Bounding box coordinates are stored normalized to [0, 1].
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def main(_):
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    path = os.path.join(os.getcwd(), 'images')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.app.run()
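class_text_to_int is the function to edit when you train on more than one class. A hedged sketch with two hypothetical labels ('cat' and 'dog' are placeholders); the integer ids must match the label map created later:

# Hypothetical multi-class version of class_text_to_int.
# The returned ids must agree with the ids in the .pbtxt label map.
def class_text_to_int(row_label):
    label_map = {'cat': 1, 'dog': 2}
    return label_map.get(row_label)  # returns None for unknown labels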

Open Anaconda Prompt and run python generate_tfrecord.py --csv_input=data/cat_train.csv --output_path=data/train.record and python generate_tfrecord.py --csv_input=data/cat_test.csv --output_path=data/test.record. The train.record and test.record files are generated in the data folder. (Note: if validating the TensorFlow Object Detection API fails with No module named 'object_detection', create a tensorflow_model.pth file under the installation path Anaconda3\Lib\site-packages whose content is the model directory paths, e.g. C:\Users\CFF\Desktop\mymodels\research and C:\Users\CFF\Desktop\mymodels\research\slim.)
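As a quick sanity check (not part of the original post), you can count how many examples ended up in each record file with the TF 1.x record iterator:

# Count the tf.Example records written to each TFRecord file (TF 1.x API).
import tensorflow as tf

for record_file in ['data/train.record', 'data/test.record']:
    count = sum(1 for _ in tf.python_io.tf_record_iterator(record_file))
    print('{}: {} examples'.format(record_file, count))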

  ④ In the data folder, create a cat_label_map.pbtxt file and open it with Spyder; its content is:

item {
  id: 1
  name: 'cat'
}

This can be modified according to the number of classes; an example with a second class is sketched below.
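For example, if a hypothetical second class 'dog' were added, the label map might look like the following; the ids must agree with class_text_to_int in generate_tfrecord.py.

item {
  id: 1
  name: 'cat'
}
item {
  id: 2
  name: 'dog'
}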

4. Create a training folder under C:\Users\CFF\Desktop\models\research\object_detection

  Download the ssd_mobilenet_v1_coco.config file from https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs, create a new text file in the training folder, name it ssd_mobilenet_v1_coco.config, and give it the following content:

# SSD with Mobilenet v1 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  ssd {
    num_classes: 1 # set this to the actual number of classes
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    anchor_generator {
      ssd_anchor_generator {
        num_layers: 6
        min_scale: 0.2
        max_scale: 0.95
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        aspect_ratios: 3.0
        aspect_ratios: 0.3333
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
    box_predictor {
      convolutional_box_predictor {
        min_depth: 0
        max_depth: 0
        num_layers_before_predictor: 0
        use_dropout: false
        dropout_keep_probability: 0.8
        kernel_size: 1
        box_code_size: 4
        apply_sigmoid_to_scores: false
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            truncated_normal_initializer {
              stddev: 0.03
              mean: 0.0
            }
          }
          batch_norm {
            train: true,
            scale: true,
            center: true,
            decay: 0.9997,
            epsilon: 0.001,
          }
        }
      }
    }
    feature_extractor {
      type: 'ssd_mobilenet_v1'
      min_depth: 16
      depth_multiplier: 1.0
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          truncated_normal_initializer {
            stddev: 0.03
            mean: 0.0
          }
        }
        batch_norm {
          train: true,
          scale: true,
          center: true,
          decay: 0.9997,
          epsilon: 0.001,
        }
      }
    }
    loss {
      classification_loss {
        weighted_sigmoid {
        }
      }
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      hard_example_miner {
        num_hard_examples: 3000
        iou_threshold: 0.99
        loss_type: CLASSIFICATION
        max_negatives_per_positive: 3
        min_negatives_per_image: 0
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    normalize_loss_by_num_matches: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: 1e-8
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  batch_size: 1
  optimizer {
    rms_prop_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.004
          decay_steps: 800720
          decay_factor: 0.95
        }
      }
      momentum_optimizer_value: 0.9
      decay: 0.9
      epsilon: 1.0
    }
  }
#  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
#  from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # empirically found to be sufficient enough to train the pets dataset. This
  # effectively bypasses the learning rate schedule (the learning rate will
  # never decay). Remove the below line to train indefinitely.
  num_steps: 200000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "data/train.record"
  }
  label_map_path: "data/cat_label_map.pbtxt"
}

eval_config: {
  num_examples: 8000
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "data/test.record"
  }
  label_map_path: "data/cat_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}

  Here, num_classes: 1 is set to the actual number of classes; input_path: "data/train.record" and input_path: "data/test.record" are the train.record and test.record files generated earlier in the data folder, and label_map_path: "data/cat_label_map.pbtxt" is the label map file also created in data.

5. Train the model

  ① In the models/research directory, run protoc object_detection/protos/*.proto --python_out=. to compile all of the .proto files into .py files (a quick import check is sketched below).
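  A quick way to confirm that the protos were compiled and that the object_detection package is importable; a sketch under the assumption that it is run with models/research (and slim) on the Python path, as set up via the .pth file above.

# If both imports succeed, the compiled protos and the object_detection package
# are on the Python path; otherwise revisit the protoc step or the .pth file.
from object_detection.protos import pipeline_pb2
from object_detection.utils import label_map_util
print('object_detection imports OK')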

  ② Open Anaconda Prompt, cd to C:\Users\CFF\Desktop\models\research\object_detection, and run the following command:

python model_main.py --pipeline_config_path=training/ssd_mobilenet_v1_coco.config --model_dir=training --num_train_steps=50000 --num_eval_steps=2000

  Training then begins. After training for a while, run tensorboard --logdir=training from C:\Users\CFF\Desktop\models\research\object_detection and open the returned URL in a browser to see the latest charts.
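  To find the newest checkpoint in training/ (needed for the export step in the next section), a one-liner like the following can help; tf.train.latest_checkpoint is TF 1.x API, run from the object_detection folder.

# Print the prefix of the most recent checkpoint in training/, e.g. training\model.ckpt-9278.
import tensorflow as tf
print(tf.train.latest_checkpoint('training'))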

6. Test your own images

  ① Put the images to be recognized in the C:\Users\CFF\Desktop\models\research\object_detection\test_images folder, named image1 through imageN.

  ② Open Anaconda Prompt, cd to C:\Users\CFF\Desktop\models\research\object_detection, and run python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_mobilenet_v1_coco.config --trained_checkpoint_prefix training/model.ckpt-9278 --output_directory cat_detection. Here model.ckpt-9278 is the last training step, which you can find in the training folder. The exported files (typically frozen_inference_graph.pb, the model checkpoint files, pipeline.config, and a saved_model directory) are generated under cat_detection.

  ③ Open Anaconda Prompt, cd to C:\Users\CFF\Desktop\models\research\object_detection, type jupyter notebook and press Enter to open the interactive environment. Download the corresponding Python file object_detection_tutorial.py locally.

  ④ Open object_detection_tutorial.py with Spyder; the code is as follows:

# coding: utf-8

# # Object Detection Demo
# Welcome to the object detection inference walkthrough!  This notebook will walk you step by step through the process of using a pre-trained model to detect objects in an image. Make sure to follow the [installation instructions](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md) before you start.

# # Imports

# In[1]:


import numpy as np
import os
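The listing above is cut off after the first imports. Below is a minimal inference sketch along the same lines as the tutorial notebook, not the full notebook: it assumes the graph exported to cat_detection in step ②, the label map data/cat_label_map.pbtxt, and a test image test_images/image1.jpg; adjust the paths to your own setup.

# Minimal TF 1.x inference sketch: load the exported frozen graph, run detection
# on one image, and save a copy with the detected boxes drawn on it.
import numpy as np
import tensorflow as tf
from PIL import Image

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

PATH_TO_FROZEN_GRAPH = 'cat_detection/frozen_inference_graph.pb'  # exported in step ②
PATH_TO_LABELS = 'data/cat_label_map.pbtxt'
NUM_CLASSES = 1  # set to your own number of classes

# Load the frozen detection graph into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

# Build the category index from the label map.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load one test image as a numpy array with a batch dimension.
image_np = np.array(Image.open('test_images/image1.jpg'))
image_np_expanded = np.expand_dims(image_np, axis=0)

with detection_graph.as_default():
    with tf.Session() as sess:
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})

# Draw the detections on the image (modifies image_np in place) and save it.
vis_util.visualize_boxes_and_labels_on_image_array(
    image_np,
    np.squeeze(boxes),
    np.squeeze(classes).astype(np.int32),
    np.squeeze(scores),
    category_index,
    use_normalized_coordinates=True,
    line_thickness=4)
Image.fromarray(image_np).save('image1_detected.jpg')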