Section 34: Object Detection, a Source Code Walkthrough of Google's Object Detection API
Posted by 大奥特曼打小怪兽
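In Section 32, "Using Google's Object Detection API for object detection and training a new model (with the VOC 2012 dataset)", we covered how to run object detection with Google's Object Detection API and how to train the detection models Google provides on our own data. Training on our own dataset involves the following main steps:
- Build our own dataset. The annotations must follow a fixed format. Then call the corresponding script under object_detection/dataset_tools to generate the tfrecord file. For example, if we want to use create_pascal_tf_record.py to generate the tfrecord file, our dataset has to be annotated the same way as the VOC 2012 dataset. You can also read create_pascal_tf_record.py to learn the annotation format it expects.
- Download the pre-trained checkpoint of the detection model we want to use and start training from it; training from scratch would cost far too much time.
- The object_detection/samples/configs folder contains a number of configuration files. Pick the one that matches the detection model we want to use and modify it as needed.
- Train with object_detection/train.py.
- Export the trained model with the export_inference_graph.py script and use it for object detection.
Here I mainly walk through the workflow of train.py.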
I. Walkthrough of train.py
The source code first:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Training executable for detection models.

This executable is used to train DetectionModels. There are two ways of
configuring the training job:

1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used and
a train_pb2.TrainConfig file to configure training parameters.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --model_config_path=model_config.pbtxt \
        --train_config_path=train_config.pbtxt \
        --input_config_path=train_input_config.pbtxt
"""

import functools
import json
import os
import tensorflow as tf

from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU. Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


if __name__ == '__main__':
  tf.app.run()
1. First, tf.app.flags is used to define the flags, so that the script can accept command-line arguments; this is equivalent to reading argv.
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU. Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS
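Several of these flags matter most. train_dir is the directory where the trained model (checkpoints) and log files are saved. pipeline_config_path is the full path to the pipeline_pb2.TrainEvalPipelineConfig configuration file. If this flag is not given, train_config_path, input_config_path and model_config_path must be given instead; those three files are simply the pipeline_pb2.TrainEvalPipelineConfig configuration file split into three parts.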
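The precedence between the flags can be sketched without TensorFlow. The snippet below is only an illustration: it uses the standard-library argparse as a stand-in for tf.app.flags, and the helper names parse_flags and config_sources are hypothetical, not part of the Object Detection API.

```python
import argparse


def parse_flags(argv):
    """Parses a subset of the train.py flags (argparse stands in for tf.app.flags)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_dir', default='')
    parser.add_argument('--pipeline_config_path', default='')
    parser.add_argument('--model_config_path', default='')
    parser.add_argument('--train_config_path', default='')
    parser.add_argument('--input_config_path', default='')
    return parser.parse_args(argv)


def config_sources(flags):
    """Returns the config files that would be read, following the order used in main()."""
    if not flags.train_dir:
        raise ValueError('`train_dir` is missing.')
    if flags.pipeline_config_path:
        # A single pipeline file wins; the three separate files are ignored.
        return [flags.pipeline_config_path]
    return [flags.model_config_path,
            flags.train_config_path,
            flags.input_config_path]


single = parse_flags(['--train_dir', 'voc/train_dir/',
                      '--pipeline_config_path', 'voc/pipeline.config'])
split = parse_flags(['--train_dir', 'voc/train_dir/',
                     '--model_config_path', 'model.config',
                     '--train_config_path', 'train.config',
                     '--input_config_path', 'input.config'])
print(config_sources(single))
print(config_sources(split))
```

The file names passed in the two calls are made up for the example; only the flag names come from train.py.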
2. Now let's look at the main function, which we read in several parts.
Suppose we run the following command in the console:
python train.py --train_dir voc/train_dir/ --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config
3. Part one
assert FLAGS.train_dir, '`train_dir` is missing.'
if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
if FLAGS.pipeline_config_path:
  configs = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  if FLAGS.task == 0:
    tf.gfile.Copy(FLAGS.pipeline_config_path,
                  os.path.join(FLAGS.train_dir, 'pipeline.config'),
                  overwrite=True)
else:
  configs = config_util.get_configs_from_multiple_files(
      model_config_path=FLAGS.model_config_path,
      train_config_path=FLAGS.train_config_path,
      train_input_config_path=FLAGS.input_config_path)
  if FLAGS.task == 0:
    for name, config in [('model.config', FLAGS.model_config_path),
                         ('train.config', FLAGS.train_config_path),
                         ('input.config', FLAGS.input_config_path)]:
      tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                    overwrite=True)
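Because we passed the train_dir and pipeline_config_path flags, the program will:
- Read the pipeline_config_path configuration file and return a dict holding its `model`, `train_config`, `train_input_config`, `eval_config` and `eval_input_config` sections.
- Copy the pipeline_config_path configuration file into the train_dir directory under the name pipeline.config (only when task is 0).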
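The copy step can be imitated with the standard library. This is a minimal sketch, assuming os and shutil as stand-ins for tf.gfile.MakeDirs and tf.gfile.Copy; the function snapshot_config and the file name my.config are hypothetical.

```python
import os
import shutil
import tempfile


def snapshot_config(config_path, train_dir, task):
    """Copies the config into train_dir as pipeline.config, but only on task 0."""
    if task != 0:
        # Other tasks neither create the directory nor copy the file.
        return None
    os.makedirs(train_dir, exist_ok=True)
    target = os.path.join(train_dir, 'pipeline.config')
    shutil.copyfile(config_path, target)  # overwrites an existing copy
    return target


workdir = tempfile.mkdtemp()
src = os.path.join(workdir, 'my.config')
with open(src, 'w') as f:
    f.write('model { }\n')

copied = snapshot_config(src, os.path.join(workdir, 'train_dir'), task=0)
skipped = snapshot_config(src, os.path.join(workdir, 'other_dir'), task=1)
print(copied, skipped)
```

Keeping a copy of the exact configuration next to the checkpoints makes it easier to tell later which settings produced a given training run.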
4. Part two
model_config = configs['model']
train_config = configs['train_config']
input_config = configs['train_input_config']

model_fn = functools.partial(
    model_builder.build,
    model_config=model_config,
    is_training=True)

def get_next(config):
  return dataset_util.make_initializable_iterator(
      dataset_builder.build(config)).get_next()

create_input_dict_fn = functools.partial(get_next, input_config)
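- The variables model_config, train_config and input_config are initialized.
- functools.partial fixes two arguments of model_builder.build, model_config and is_training, and returns a new function, model_fn. This function is important: it contains the implementation of the detection model and is covered in detail later.
- functools.partial likewise fixes the argument input_config of get_next, giving create_input_dict_fn. This function mainly reads the tfrecord data and is also covered later.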
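How functools.partial pre-binds arguments can be seen in isolation. In this sketch, build and get_next are hypothetical stand-ins that only return dictionaries; they do not build a model or read any data.

```python
import functools


def build(model_config, is_training, add_summaries=True):
    """Hypothetical stand-in for model_builder.build; returns a description only."""
    return {'config': model_config,
            'is_training': is_training,
            'add_summaries': add_summaries}


def get_next(config):
    """Hypothetical stand-in for reading the next element of the input pipeline."""
    return {'source': config}


# Bind the keyword arguments now; the caller later invokes model_fn() with none.
model_fn = functools.partial(build, model_config='faster_rcnn', is_training=True)
# Bind the single positional argument of get_next.
create_input_dict_fn = functools.partial(get_next, 'voc/pascal_train.record')

model = model_fn()
batch = create_input_dict_fn()
print(model)
print(batch)
```

The point of this design is that trainer.train receives two zero-argument callables and does not need to know anything about the configuration objects behind them.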
5. Part three
env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_data = env.get('cluster', None)
cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
task_data = env.get('task', None) or {'type': 'master', 'index': 0}
task_info = type('TaskSpec', (object,), task_data)

# Parameters for a single worker.
ps_tasks = 0
worker_replicas = 1
worker_job_name = 'lonely_worker'
task = 0
is_chief = True
master = ''

if cluster_data and 'worker' in cluster_data:
  # Number of total worker replicas include "worker"s and the "master".
  worker_replicas = len(cluster_data['worker']) + 1
if cluster_data and 'ps' in cluster_data:
  ps_tasks = len(cluster_data['ps'])

if worker_replicas > 1 and ps_tasks < 1:
  raise ValueError('At least 1 ps task is needed for distributed training.')

if worker_replicas >= 1 and ps_tasks > 0:
  # Set up distributed training.
  server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                           job_name=task_info.type,
                           task_index=task_info.index)
  if task_info.type == 'ps':
    server.join()
    return

  worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
  task = task_info.index
  is_chief = (task_info.type == 'master')
  master = server.target
- This code mainly implements distributed training. For details, see Section 8, Configuring Distributed TensorFlow.
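The bookkeeping in this part can be tried out without a cluster. The sketch below repeats the TF_CONFIG parsing from the listing in plain Python and leaves out tf.train.Server entirely; the function name parse_cluster and the host names are assumptions made for the example.

```python
import json


def parse_cluster(tf_config_json):
    """Derives the worker/ps bookkeeping from a TF_CONFIG-style JSON string."""
    env = json.loads(tf_config_json or '{}')
    cluster_data = env.get('cluster', None)
    task_data = env.get('task', None) or {'type': 'master', 'index': 0}
    # Same trick as train.py: build a throwaway class whose attributes are the dict keys.
    task_info = type('TaskSpec', (object,), task_data)

    # Defaults for a single, non-distributed worker.
    ps_tasks = 0
    worker_replicas = 1
    worker_job_name = 'lonely_worker'
    is_chief = True

    if cluster_data and 'worker' in cluster_data:
        # The "master" also counts as a worker replica.
        worker_replicas = len(cluster_data['worker']) + 1
    if cluster_data and 'ps' in cluster_data:
        ps_tasks = len(cluster_data['ps'])
    if worker_replicas > 1 and ps_tasks < 1:
        raise ValueError('At least 1 ps task is needed for distributed training.')
    if worker_replicas >= 1 and ps_tasks > 0:
        worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
        is_chief = (task_info.type == 'master')

    return {'worker_replicas': worker_replicas, 'ps_tasks': ps_tasks,
            'worker_job_name': worker_job_name, 'is_chief': is_chief}


tf_config = json.dumps({
    'cluster': {'master': ['host0:2222'],
                'worker': ['host1:2222', 'host2:2222'],
                'ps': ['host3:2222']},
    'task': {'type': 'worker', 'index': 1},
})
distributed = parse_cluster(tf_config)
local = parse_cluster('')
print(distributed)
print(local)
```

With no TF_CONFIG set, as in the command used in this article, every value keeps its single-worker default and no server is started.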
6. Part four
graph_rewriter_fn = None
if 'graph_rewriter_config' in configs:
  graph_rewriter_fn = graph_rewriter_builder.build(
      configs['graph_rewriter_config'], is_training=True)

trainer.train(
    create_input_dict_fn,
    model_fn,
    train_config,
    master,
    task,
    FLAGS.num_clones,
    worker_replicas,
    FLAGS.clone_on_cpu,
    ps_tasks,
    worker_job_name,
    is_chief,
    FLAGS.train_dir,
    graph_hook_fn=graph_rewriter_fn)
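- Because graph_rewriter_config is not defined in our configuration, graph_rewriter_fn stays None and trainer.train is called directly; it starts reading the data, preprocesses it, and trains.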
II. The dataset_builder.build function
The code first:
def build(input_reader_config, transform_input_data_fn=None,
          batch_size=None, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    # Initialize the fields to decode and the handler that decodes each field.
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file)

    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # Read records from config.input_path with tf.data.TFRecordDataset, decode
    # them with process_fn, and prefetch input_reader_config.prefetch_size
    # records.
    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
The overall flow:
- Get the path of the training-set tfrecord file and of the label_map_path file. The input_reader_config settings are as follows:
train_input_reader: {
  tf_record_input_reader {
    input_path: "voc/pascal_train.record"
  }
  label_map_path: "voc/pascal_label_map.pbtxt"
}
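- Initialize the fields to decode and the handler that decodes each field.
- Call tf.data.TFRecordDataset to read data from config.input_path, call process_fn (which defines how records are decoded) on each record, and prefetch input_reader_config.prefetch_size records.
- If batch_size is given, apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset; a trailing group smaller than batch_size is dropped.
- Return the resulting tf.data.Dataset. In train.py, get_next then wraps it with make_initializable_iterator to obtain an iterator.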
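What "padded batch and drop remainder" means can be shown on plain Python lists. This is only a sketch of the idea, not the TensorFlow implementation: it pads every record to the longest record in its batch, which corresponds to the dynamic-shape case, whereas the real function pads to the shapes passed in padding_shapes.

```python
def padded_batch_and_drop_remainder(records, batch_size, pad_value=0):
    """Groups records into full batches, padding each record to the batch's max length."""
    batches = []
    # Only whole batches are kept; leftover records at the end are dropped.
    usable = len(records) - len(records) % batch_size
    for start in range(0, usable, batch_size):
        group = records[start:start + batch_size]
        width = max(len(r) for r in group)
        batches.append([r + [pad_value] * (width - len(r)) for r in group])
    return batches


# Five variable-length records, e.g. a different number of boxes per image.
records = [[1], [2, 3], [4, 5, 6], [7], [8, 9]]
batches = padded_batch_and_drop_remainder(records, batch_size=2)
print(batches)
```

Padding is needed because images carry different numbers of groundtruth boxes, and dropping the remainder keeps the batch dimension fixed for every step.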
III. The model_builder.build function
The code is as follows:
def build(model_config, is_training, add_summaries=True,
          add_background_class=True):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.
    add_background_class: Whether to add an implicit background class to one-hot
      encodings of groundtruth labels. Set to false if using groundtruth labels
      with an explicit background class or using multiclass scores instead of
      truth in the case of distillation. Ignored in the case of faster_rcnn.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training, add_summaries,
                            add_background_class)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
                                    add_summaries)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
First the detection model in use is determined. Because we use faster_rcnn_inception_resnet_v2, _build_faster_rcnn_model is called with the arguments faster_rcnn, is_training and add_summaries. The contents of faster_rcnn are:
model {
  faster_rcnn {
    num_classes: 20
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_inception_resnet_v2'
      first_stage_features_stride: 8
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 8
        width_stride: 8
      }
    }
    first_stage_atrous_rate: 2
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 300
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 17
    maxpool_kernel_size: 1
    maxpool_stride: 1
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}
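The dispatch performed by model_builder.build can be sketched with a plain dictionary. Here a dict with exactly one key imitates the proto oneof that WhichOneof('model') inspects, and the two builder functions are hypothetical stubs that return tuples instead of real DetectionModel objects.

```python
def build_ssd(config, is_training, add_summaries, add_background_class):
    """Hypothetical stub standing in for _build_ssd_model."""
    return ('ssd', config['num_classes'], is_training)


def build_faster_rcnn(config, is_training, add_summaries):
    """Hypothetical stub standing in for _build_faster_rcnn_model."""
    return ('faster_rcnn', config['num_classes'], is_training)


def build(model_config, is_training, add_summaries=True,
          add_background_class=True):
    """Picks a builder from the single key that is set in model_config."""
    if not isinstance(model_config, dict) or len(model_config) != 1:
        raise ValueError('model_config must set exactly one meta architecture.')
    meta_architecture = next(iter(model_config))  # plays the role of WhichOneof
    if meta_architecture == 'ssd':
        return build_ssd(model_config['ssd'], is_training, add_summaries,
                         add_background_class)
    if meta_architecture == 'faster_rcnn':
        # add_background_class is not forwarded in the faster_rcnn branch.
        return build_faster_rcnn(model_config['faster_rcnn'], is_training,
                                 add_summaries)
    raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))


model = build({'faster_rcnn': {'num_classes': 20}}, is_training=True)
print(model)
```

With the configuration shown above, the faster_rcnn branch is the one taken, so only the fields under faster_rcnn are passed on to the builder.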
Next, let's look at the source of _build_faster_rcnn_model:
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds R-FCN model if the second_stage_box_predictor in the config is of type
  `rfcn_box_predictor` else builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    FasterRCNNMetaArch based on the config.

  Raises:
    ValueError: If frcnn_config.type is not recognized (i.e. not registered in
      model_class_map).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  feature_extractor = _build_faster_rcnn_feature_extractor(
      frcnn_config.feature_extractor, is_training,
      frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
      frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth =