mmdetection Source Code Reading Notes
Posted by 木偶Roy
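Preface: this article was compiled by the editors of 小常识网 (cha138.com) and introduces notes on reading the mmdetection source code; we hope it is of some reference value to you.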
MMDet
I needed to read through the source code for work, focusing mainly on training and inference, which involves the Registry, Runner, and Hook components.
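Core libraries
The core libraries are MMDetection, MMSegmentation, MMDetection3d, and MMCV.
MMDetection3d: supports models and datasets for 3D object detection
MMDetection & MMSegmentation: support models for standard object detection and segmentation
MMCV: the foundational library of the MM series. It provides: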
- Universal IO APIs
- Image/Video processing
- Image and annotation visualization
- Useful utilities (progress bar, timer, ...)
- PyTorch runner with hooking mechanism
- Various CNN architectures
- High-quality implementation of common CUDA ops
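Core components
The core components involved in training and inference with MMDet are MMCV/Registry and MMCV/Runner. Another very important piece is Hooks: the source code shows that the developers use hooks extensively.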
Registry
Registry maps type names to classes. It is the tool used to instantiate objects in MMDet, including models, datasets, optimizers, and so on.
Overall flow
Examples:
1. Register a class: build a mapping from class name -> class
VOXEL_ENCODERS = Registry('voxel_encoder')
@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...
2. Instantiate an object: parse the config and instantiate the class it names
def build_voxel_encoder(cfg):
    """Build voxel encoder."""
    return build(cfg, VOXEL_ENCODERS)
def build(cfg, registry, default_args=None):
    ...
    return build_from_cfg(cfg, registry, default_args)
def build_from_cfg(cfg, registry, default_args=None):
    ...
    args = cfg.copy()
    obj_type = args.pop('type')  # the 'type' key selects the registered class
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)
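The following minimal, self-contained sketch makes the register-then-build flow concrete without depending on mmcv. `ToyRegistry` and `toy_build_from_cfg` are simplified stand-ins written for illustration, not mmcv's API. `HardSimpleVFE` here is a plain class with an illustrative `num_features` argument, included only to show how the remaining config keys become constructor arguments.

```python
class ToyRegistry:
    """Toy stand-in for mmcv's Registry: maps class names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Used as a decorator: @REGISTRY.register_module()
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def toy_build_from_cfg(cfg, registry):
    args = dict(cfg)              # copy so the caller's cfg is not mutated
    obj_type = args.pop('type')   # 'type' selects the class
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return obj_cls(**args)        # remaining keys become constructor kwargs


VOXEL_ENCODERS = ToyRegistry('voxel_encoder')


@VOXEL_ENCODERS.register_module()
class HardSimpleVFE:
    def __init__(self, num_features=4):
        self.num_features = num_features


cfg = dict(type='HardSimpleVFE', num_features=5)
encoder = toy_build_from_cfg(cfg, VOXEL_ENCODERS)
print(type(encoder).__name__, encoder.num_features)  # HardSimpleVFE 5
```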
Core implementation
MMCV/utils/registry.py: the registry mainly uses a dict, self._module_dict, to store the mapping from class names to classes.
class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """
    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class
    def register_module(self, name=None, force=False, module=None):
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)
        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')
        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register
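The three usages in the register_module docstring can be exercised with the simplified re-creation below: a plain decorator, a decorator with a custom name, and a direct call with module=. It also covers the duplicate-name check. It is a sketch of the logic above, not mmcv itself. `MiniRegistry` is a hypothetical name, and the deprecated register_module(SomeClass) path and the is_seq_of type check are left out.

```python
import inspect


class MiniRegistry:
    """Simplified re-creation of the register_module logic shown above."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError(f'module must be a class, but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]       # one class may get several names
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered in {self.name}')
            self._module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        if module is not None:                # normal call: reg.register_module(module=Cls)
            self._register_module(module, name, force)
            return module

        def _register(cls):                   # decorator: @reg.register_module(...)
            self._register_module(cls, name, force)
            return cls
        return _register


backbones = MiniRegistry('backbone')


@backbones.register_module()
class ResNet:
    pass


@backbones.register_module(name=['mnet', 'MobileNetV2'])
class MobileNet:
    pass


class ResNetV2:
    pass


# force=True overrides the existing 'ResNet' entry
backbones.register_module(name='ResNet', force=True, module=ResNetV2)

try:
    backbones.register_module(module=ResNet)  # 'ResNet' is taken and force=False
except KeyError as err:
    print('KeyError:', err)

print(sorted(backbones._module_dict))  # ['MobileNetV2', 'ResNet', 'mnet']
```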
Registering classes
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')
@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...
build
def build(cfg, registry, default_args=None):
    """Build a module.
    Args:
        cfg (dict, list[dict]): The config of modules, it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.
    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        # Note: this simply chains the individual sub-modules together in order
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
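Two details of build_from_cfg are easy to miss. First, default_args only fill in keys that cfg does not already set, because they are applied with setdefault. Second, construction errors are re-raised with the class name prefixed. The sketch below reproduces just those two behaviours. It assumes a plain dict in place of a Registry and a made-up `Dataset` class, and `build_with_defaults` is a hypothetical name. The call with dict(test_mode=True) resembles the build_dataset(cfg.data.val, dict(test_mode=True)) call in the training code, assuming that argument ends up as default_args.

```python
def build_with_defaults(cfg, registry, default_args=None):
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)   # only fills keys missing from cfg
    obj_type = args.pop('type')
    obj_cls = registry[obj_type]           # plain dict used as the "registry"
    try:
        return obj_cls(**args)
    except Exception as e:
        # re-raise the same exception type with the class name prefixed
        raise type(e)(f'{obj_cls.__name__}: {e}')


class Dataset:
    def __init__(self, ann_file, test_mode=False):
        self.ann_file = ann_file
        self.test_mode = test_mode


registry = {'Dataset': Dataset}
cfg = dict(type='Dataset', ann_file='train.json', test_mode=False)

# keys already present in cfg win over default_args
ds = build_with_defaults(cfg, registry, default_args=dict(test_mode=True, ann_file='x.json'))
print(ds.ann_file, ds.test_mode)  # train.json False

# a missing key is taken from default_args
ds2 = build_with_defaults(dict(type='Dataset', ann_file='val.json'), registry, dict(test_mode=True))
print(ds2.test_mode)  # True

try:
    build_with_defaults(dict(type='Dataset'), registry)  # ann_file is missing
except TypeError as e:
    print(e)  # message starts with "Dataset: "
```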
Runner
MMCV/runner
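The directory structure of mmcv/runner shows that the runner is mainly responsible for checkpointing, training, validation, the optimizer, and hooks.
The train call chain across mmdet3d, mmdet, and mmcv can be summarized as: Mmdet3d/train.py -> mmdet/train_detector() -> mmcv/runner.run() -> mmcv/epoch_based_runner.py/train()
Implementation
mmcv/epoch_based_runner.py/train()
The train function is where the loop over the dataloader actually runs. It also triggers many hooks. All of them were registered into the runner right after it was instantiated, via the register_hook method of BaseRunner, the parent class of EpochBasedRunner.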
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1
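A stripped-down runner is enough to check the order in which train() fires its hook points. This is a sketch, not mmcv's EpochBasedRunner. `MiniEpochRunner` and `RecordingHook` are hypothetical names, the model call is replaced by summing the batch, and time.sleep(2) is dropped.

```python
class RecordingHook:
    """Records every hook point the runner triggers."""

    def __init__(self):
        self.calls = []

    def before_train_epoch(self, runner):
        self.calls.append('before_train_epoch')

    def before_train_iter(self, runner):
        self.calls.append(f'before_train_iter[{runner.inner_iter}]')

    def after_train_iter(self, runner):
        self.calls.append(f'after_train_iter[{runner.inner_iter}]')

    def after_train_epoch(self, runner):
        self.calls.append('after_train_epoch')


class MiniEpochRunner:
    def __init__(self, max_epochs):
        self._max_epochs = max_epochs
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

    @property
    def inner_iter(self):
        return self._inner_iter

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def run_iter(self, data_batch):
        self.outputs = {'loss': sum(data_batch)}  # stand-in for model.train_step

    def train(self, data_loader):
        # same control flow as EpochBasedRunner.train above
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch)
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1


hook = RecordingHook()
runner = MiniEpochRunner(max_epochs=2)
runner._hooks.append(hook)
runner.train([[1, 2], [3, 4]])
print(hook.calls)
```

One epoch over two batches produces the epoch-level hooks once and the iteration-level pair once per batch, in that nesting.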
mmcv/epoch_based_runner.py/run_iter()
def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
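run_iter relies on the model's train_step returning a dict. When that dict contains log_vars, those values are pushed into the log buffer together with num_samples. The sketch below assumes the buffer reports a sample-weighted mean; `MiniLogBuffer` is a simplified stand-in, not mmcv's actual LogBuffer implementation. It uses a toy model whose 'loss' is just the batch mean.

```python
class MiniLogBuffer:
    """Simplified stand-in for a log buffer: keeps values and sample counts."""

    def __init__(self):
        self.val_history = {}
        self.n_history = {}

    def update(self, log_vars, count=1):
        for key, val in log_vars.items():
            self.val_history.setdefault(key, []).append(val)
            self.n_history.setdefault(key, []).append(count)

    def average(self, key):
        # mean of all logged values, weighted by how many samples produced each
        vals, nums = self.val_history[key], self.n_history[key]
        return sum(v * n for v, n in zip(vals, nums)) / sum(nums)


class ToyModel:
    def train_step(self, data_batch, optimizer):
        loss = sum(data_batch) / len(data_batch)
        # must return a dict; 'log_vars' and 'num_samples' feed the log buffer
        return dict(loss=loss, log_vars={'loss': loss}, num_samples=len(data_batch))


def run_iter(model, data_batch, log_buffer):
    outputs = model.train_step(data_batch, optimizer=None)
    if not isinstance(outputs, dict):
        raise TypeError('"model.train_step()" must return a dict')
    if 'log_vars' in outputs:
        log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    return outputs


buf = MiniLogBuffer()
run_iter(ToyModel(), [1.0, 3.0], buf)            # loss 2.0 over 2 samples
run_iter(ToyModel(), [5.0, 5.0, 5.0, 5.0], buf)  # loss 5.0 over 4 samples
print(buf.average('loss'))  # (2*2 + 5*4) / 6 = 4.0
```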
Hook
- Hook base class
mmcv/runner/hooks/hook.py
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook:
    def before_run(self, runner):
        pass
    def after_run(self, runner):
        pass
    def before_epoch(self, runner):
        pass
    def after_epoch(self, runner):
        pass
    def before_iter(self, runner):
        pass
    def after_iter(self, runner):
        pass
    def before_train_epoch(self, runner):
        self.before_epoch(runner)
    def before_val_epoch(self, runner):
        self.before_epoch(runner)
    def after_train_epoch(self, runner):
        self.after_epoch(runner)
    def after_val_epoch(self, runner):
        self.after_epoch(runner)
    def before_train_iter(self, runner):
        self.before_iter(runner)
    def before_val_iter(self, runner):
        self.before_iter(runner)
    def after_train_iter(self, runner):
        self.after_iter(runner)
    def after_val_iter(self, runner):
        self.after_iter(runner)
    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False
    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False
    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False
    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)
    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs
    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters
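Every train/val-specific method in the base class forwards to a generic one. A hook can therefore override after_iter once and run in both phases, while helpers such as every_n_iters handle the interval logic. The minimal sketch below uses a trimmed copy of the base class, a hypothetical `RecordEveryNHook`, and a fake runner that only carries an iter counter.

```python
class Hook:
    # Trimmed copy of the base class above: train/val variants delegate to after_iter.
    def after_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


class RecordEveryNHook(Hook):  # hypothetical example hook
    def __init__(self, interval):
        self.interval = interval
        self.fired_at = []

    def after_iter(self, runner):  # one override covers train and val
        if self.every_n_iters(runner, self.interval):
            self.fired_at.append(runner.iter)


class FakeRunner:
    def __init__(self):
        self.iter = 0


runner = FakeRunner()
hook = RecordEveryNHook(interval=3)
for it in range(10):
    runner.iter = it
    hook.after_train_iter(runner)
print(hook.fired_at)  # [2, 5, 8]
```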
All hooks
__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook'
]
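- Embedding hooks into the runner
The runner stores the instantiated hook objects in a list kept sorted by priority (effectively a priority queue), which guarantees that hooks are called in priority order. The priority levels are defined as follows; a lower value means a higher priority: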
"""Hook priority levels.
+------------+------------+
| Level      | Value      |
+============+============+
| HIGHEST    | 0          |
+------------+------------+
| VERY_HIGH  | 10         |
+------------+------------+
| HIGH       | 30         |
+------------+------------+
| NORMAL     | 50         |
+------------+------------+
| LOW        | 70         |
+------------+------------+
| VERY_LOW   | 90         |
+------------+------------+
| LOWEST     | 100        |
+------------+------------+
"""
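register_hook accepts a priority given as an int, a level name, or an enum member, and normalizes it to an int before sorting. Below is a sketch of that conversion, modelled on the table above. The names `Priority` and `get_priority` match the ones referenced in register_hook, but this body is a re-creation; mmcv's own version may differ in details such as extra levels or error messages.

```python
from enum import Enum


class Priority(Enum):
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    NORMAL = 50
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Convert an int, a level name, or a Priority member to an int (lower = earlier)."""
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    if isinstance(priority, Priority):
        return priority.value
    if isinstance(priority, str):
        return Priority[priority.upper()].value  # KeyError for unknown names
    raise TypeError('priority must be an int, a str or a Priority member')


print(get_priority('normal'), get_priority(Priority.HIGH), get_priority(75))  # 50 30 75
```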
The runner provides two ways to register a hook:
- register_hook
- register_hook_from_cfg
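Both methods are implemented in the runner base class BaseRunner (mmcv/runner/base_runner.py). register_hook traverses the hook list from back to front. When it finds a hook whose priority is higher than or equal to the current hook's (that is, whose priority value is less than or equal), it inserts the current hook right after that hook. If no such hook exists, the current hook goes to the front of the list. As a result, hooks with the same priority are triggered in the order they were registered.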
def register_hook(self, hook, priority='NORMAL'):
    """Register a hook into the hook list.
    The hook will be inserted into a priority queue, with the specified
    priority (See :class:`Priority` for details of priorities).
    For hooks with the same priority, they will be triggered in the same order
    as they are registered.
    Args:
        hook (:obj:`Hook`): The hook to be registered.
        priority (int or str or :obj:`Priority`): Hook priority.
            Lower value means higher priority.
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')
    priority = get_priority(priority)
    hook.priority = priority
    # insert the hook to a sorted list
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)
def register_hook_from_cfg(self, hook_cfg):
    """Register a hook from its cfg.
    Args:
        hook_cfg (dict): Hook config. It should have at least keys 'type'
          and 'priority' indicating its type and priority.
    Notes:
        The specific hook class to register should not use 'type' and
        'priority' arguments during initialization.
    """
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop('priority', 'NORMAL')
    hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
    self.register_hook(hook, priority=priority)
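The backward scan in register_hook amounts to a stable insertion: a new hook lands after the last hook whose priority value is less than or equal to its own. The self-contained sketch below registers hooks out of order to show this. `HookList` and `NamedHook` are hypothetical helpers, and priorities are passed as plain ints.

```python
class NamedHook:
    def __init__(self, name):
        self.name = name


class HookList:
    def __init__(self):
        self._hooks = []

    def register_hook(self, hook, priority):
        hook.priority = priority
        # walk from the back; insert after the last hook whose value is <= ours
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                return
        self._hooks.insert(0, hook)  # higher priority than everything so far


hooks = HookList()
hooks.register_hook(NamedHook('logger'), 90)
hooks.register_hook(NamedHook('optimizer'), 50)
hooks.register_hook(NamedHook('lr'), 10)
hooks.register_hook(NamedHook('custom'), 50)
print([h.name for h in hooks._hooks])  # ['lr', 'optimizer', 'custom', 'logger']
```

The result equals a stable sort by priority value. This is why 'custom' ends up after 'optimizer' even though both use 50.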
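- Calling hooks in the runner
call_hook walks the hook list in order, so priority order is respected. As the implementation shows, every call to call_hook invokes every hook in the list. Each hook runs its own implementation of the method named fn_name; hooks that do not override it fall back to the no-op defined in the Hook base class.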
def call_hook(self, fn_name):
    """Call all hooks.
    Args:
        fn_name (str): The function name in each hook to be called, such as
            "before_train_epoch".
    """
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
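call_hook looks each method up by name with getattr and calls it on every hook. Hooks that do not override a given method therefore hit the base-class no-op. The minimal sketch below uses made-up `StartHook`/`EndHook` classes and a two-method base class.

```python
class Hook:
    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass


class StartHook(Hook):
    def before_run(self, runner):
        runner.log.append('start')


class EndHook(Hook):
    def after_run(self, runner):
        runner.log.append('end')


class Runner:
    def __init__(self, hooks):
        self._hooks = hooks
        self.log = []

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)  # look the method up by name, then call it


runner = Runner([StartHook(), EndHook()])
runner.call_hook('before_run')  # EndHook.before_run is the inherited no-op
runner.call_hook('after_run')   # StartHook.after_run is the inherited no-op
print(runner.log)  # ['start', 'end']
```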
- Registering hooks before training
After the runner object is instantiated, mmdet's train_detector() registers the hooks the runner will use:
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))
if distributed:
    if isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
    # Support batch_size > 1 in validation
    val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
    if val_samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg.data.val.pipeline = replace_ImageToTensor(
            cfg.data.val.pipeline)
    val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
    val_dataloader = build_dataloader(
        val_dataset,
        samples_per_gpu=val_samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    eval_cfg = cfg.get('evaluation', {})
    eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# user-defined hooks
if cfg.get('custom_hooks', None):
    custom_hooks = cfg.custom_hooks
    assert isinstance(custom_hooks, list), \
        f'custom_hooks expect list type, but got {type(custom_hooks)}'
    for hook_cfg in cfg.custom_hooks:
        assert isinstance(hook_cfg, dict), \
            'Each item in custom_hooks expects dict type, but got ' \
            f'{type(hook_cfg)}'
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = build_from_cfg(hook_cfg, HOOKS)
        runner.register_hook(hook, priority=priority)
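The custom_hooks loop at the end expects a list of dicts. It strips the optional priority key before building each hook, so priority never reaches the hook's constructor. The sketch below walks a hypothetical config through the same steps. `MyCheckHook`, `MyDebugHook` and `build_custom_hooks` are invented names, and a plain dict plus a name-to-number table stand in for HOOKS and get_priority.

```python
# Hypothetical config fragment in the shape the loop expects
custom_hooks = [
    dict(type='MyCheckHook'),
    dict(type='MyDebugHook', interval=100, priority='VERY_LOW'),
]

PRIORITY = dict(HIGHEST=0, VERY_HIGH=10, HIGH=30, NORMAL=50,
                LOW=70, VERY_LOW=90, LOWEST=100)


class MyCheckHook:
    pass


class MyDebugHook:
    def __init__(self, interval):
        self.interval = interval


HOOKS = {'MyCheckHook': MyCheckHook, 'MyDebugHook': MyDebugHook}


def build_custom_hooks(custom_hooks):
    assert isinstance(custom_hooks, list)
    built = []
    for hook_cfg in custom_hooks:
        assert isinstance(hook_cfg, dict)
        hook_cfg = hook_cfg.copy()                     # don't mutate the config
        priority = hook_cfg.pop('priority', 'NORMAL')  # not a constructor argument
        hook_cls = HOOKS[hook_cfg.pop('type')]
        built.append((PRIORITY[priority], hook_cls(**hook_cfg)))
    return built


for prio, hook in build_custom_hooks(custom_hooks):
    print(prio, type(hook).__name__)
# 50 MyCheckHook
# 90 MyDebugHook
```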
mmcv/runner/base_runner.py
def register_training_hooks(self,
                            lr_config,
                            optimizer_config=None,
                            checkpoint_config=None,
                            log_config=None,
                            momentum_config=None,
                            timer_config=dict(type='IterTimerHook'),
                            custom_hooks_config=None):
    """Register default and custom hooks for training.
    Default and custom hooks include:
        Hooks                     Priority
        - LrUpdaterHook           10
        - MomentumUpdaterHook     30
        - OptimizerStepperHook    50
        - CheckpointSaverHook     70
        - IterTimerHook           80
        - LoggerHook(s)           90
        - CustomHook(s)           50 (default)
    """
    self.register_lr_hook(lr_config)
    self.register_momentum_hook(momentum_config)
    self.register_optimizer_hook(optimizer_config)
    self.register_checkpoint_hook(checkpoint_config)
    self.register_timer_hook(timer_config)
    self.register_logger_hooks(log_config)
    self.register_custom_hooks(custom_hooks_config)
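Because register_hook behaves like a stable sort, the trigger order of the default training hooks can be worked out from the docstring numbers alone. The small sketch below assumes the hooks are registered in the order of the calls above and that a single custom hook uses the default priority of 50. The names are the docstring's labels, not necessarily the exact class names.

```python
# (hook label, priority value) in registration order, numbers taken from the docstring
registration_order = [
    ('LrUpdaterHook', 10),
    ('MomentumUpdaterHook', 30),
    ('OptimizerStepperHook', 50),
    ('CheckpointSaverHook', 70),
    ('IterTimerHook', 80),
    ('LoggerHook', 90),
    ('CustomHook', 50),
]

# sorted() is stable, matching register_hook's insertion rule for equal priorities
trigger_order = [name for name, _ in sorted(registration_order, key=lambda item: item[1])]
print(trigger_order)
# ['LrUpdaterHook', 'MomentumUpdaterHook', 'OptimizerStepperHook', 'CustomHook',
#  'CheckpointSaverHook', 'IterTimerHook', 'LoggerHook']
```

Under these assumptions, within after_train_iter the optimizer step runs before checkpointing and logging, and a custom hook at the default priority runs right after the optimizer step.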
- Calling hooks during training and validation
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')
Examples:
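During training, the self.call_hook('after_train_iter') call is where backpropagation and the gradient update should happen. In other words, OptimizerHook should implement an after_train_iter method that calls .backward() and optimizer.step(), and that is what it does: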
@HOOKS.register_module()
class OptimizerHook(Hook):
    def __init__(self, grad_clip=None):
        self.grad_clip = grad_clip
    def clip_grads(self, params):
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()
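clip_grads forwards self.grad_clip as keyword arguments to clip_grad_norm_, so the config typically holds keys such as max_norm and norm_type. The pure-Python sketch below shows what that call does, using global-norm clipping on plain lists. It assumes a finite norm_type and follows torch's documented behaviour: every gradient is scaled by max_norm / (total_norm + 1e-6) when that factor is below 1, and the norm measured before clipping is returned. It is an illustration, not the torch implementation, and `clip_grad_norm` here is a hypothetical helper name.

```python
def clip_grad_norm(grads, max_norm, norm_type=2.0):
    """Clip a list of gradient vectors (lists of floats) in place.

    Returns the total norm measured before clipping.
    """
    # total norm = p-norm over all gradient entries taken together
    total_norm = sum(abs(g) ** norm_type for grad in grads for g in grad) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:  # only ever scale down
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # total L2 norm = 5
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)  # 5.0
```

After the call, the gradients' combined L2 norm is just under max_norm. The returned pre-clipping value is what OptimizerHook logs as grad_norm.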