mmdetection source code reading notes

Posted 木偶Roy

MMDet

I had to read through the source code for work, mainly the training and inference paths. The parts covered here are Registry, Runner and Hook.

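Core libraries

The core libraries are MMDetection, MMSegmentation, MMDetection3d and MMCV.

MMDetection3d: models and datasets for 3D object detection

MMDetection & MMSegmentation: models for conventional object detection and segmentation

MMCV: the foundation library of the MM series. It provides: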

  • Universal IO APIs
  • Image/Video processing
  • Image and annotation visualization
  • Useful utilities (progress bar, timer, ...)
  • PyTorch runner with hooking mechanism
  • Various CNN architectures
  • High-quality implementation of common CUDA ops

Core components

Training and inference with MMDet mainly involve two components, MMCV/Registry and MMCV/Runner. Hooks are another very important piece, and the source code uses them extensively.

Registry

Registry is the tool MMDet uses to instantiate all of its objects, including models, datasets, optimizers and so on.

Overall flow

Examples:

1. Register a class: build a class name -> class mapping
VOXEL_ENCODERS = Registry('voxel_encoder')

@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...

2. Instantiate an object: parse the config and instantiate the class
def build_voxel_encoder(cfg):
    """Build voxel encoder."""
    return build(cfg, VOXEL_ENCODERS)

def build(cfg, registry, default_args=None):
    ...
    return build_from_cfg(cfg, registry, default_args)

def build_from_cfg(cfg, registry, default_args=None):
    ...
    args = cfg.copy()
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)
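You can try the flow above without mmcv. Below is a minimal, self-contained sketch of the same pattern. It assumes only what the snippets show: a dict from names to classes, a decorator that fills the dict, and a builder that pops type from a config dict and passes the remaining keys to the constructor. MiniRegistry, this build_from_cfg and SimpleVFE are hypothetical stand-ins, not mmcv's code. Parent/child registries and scopes are left out.

```python
import inspect


class MiniRegistry:
    """Hypothetical, stripped-down registry: maps a name to a class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        # Used as a decorator: @REG.register_module()
        def _register(cls):
            if not inspect.isclass(cls):
                raise TypeError(f'module must be a class, but got {type(cls)}')
            key = name or cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg(cfg, registry):
    """Look up cfg['type'] and call that class with the remaining keys."""
    args = dict(cfg)                 # copy, so the caller's config is untouched
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return obj_cls(**args)


VOXEL_ENCODERS = MiniRegistry('voxel_encoder')


@VOXEL_ENCODERS.register_module()
class SimpleVFE:  # hypothetical class standing in for HardSimpleVFE
    def __init__(self, num_features=4):
        self.num_features = num_features


cfg = dict(type='SimpleVFE', num_features=5)
encoder = build_from_cfg(cfg, VOXEL_ENCODERS)
print(type(encoder).__name__, encoder.num_features)  # SimpleVFE 5
print(cfg)  # {'type': 'SimpleVFE', 'num_features': 5}
```

The config is copied before type is popped, so the same config dict can be built more than once. The cfg.copy() call in the snippet above does the same job.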

Core implementation

MMCV/utils/registry.py: the Registry class mainly uses a dict, self._module_dict, to map class names to classes.

class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """
    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
        
    # (other methods such as infer_scope, get and _add_children are omitted here)
    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
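A small stand-alone class can exercise three things from the code above: the two call styles at the end of register_module (passing module= versus using it as a decorator), the list-of-names form, and the force flag. ToyRegistry below is a hypothetical re-creation of only _register_module and register_module as shown above. It drops the type checks on name and the deprecated path. It shows two behaviours: a list of names registers aliases, and registering an existing name raises KeyError unless force=True.

```python
import inspect


class ToyRegistry:
    """Hypothetical re-creation of the registration logic shown above."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError(f'module must be a class, but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:          # a list registers several aliases
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered in {self.name}')
            self._module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        if module is not None:            # plain call: reg.register_module(module=Cls)
            self._register_module(module, name, force)
            return module

        def _register(cls):               # decorator: @reg.register_module()
            self._register_module(cls, name, force)
            return cls
        return _register


backbones = ToyRegistry('backbone')


@backbones.register_module(name=['MobileNet', 'mnet'])
class MobileNet:
    pass


class ResNet:
    pass


backbones.register_module(module=ResNet)

try:
    backbones.register_module(module=ResNet)       # same name a second time
except KeyError as err:
    print('KeyError:', err)

# force=True silently replaces the existing entry
backbones.register_module(name='ResNet', force=True, module=MobileNet)
print(sorted(backbones._module_dict))              # ['MobileNet', 'ResNet', 'mnet']
print(backbones._module_dict['ResNet'] is MobileNet)  # True
```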

Registering classes

VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')

@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...

build

def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        # Note: this simply chains the individual sub-modules together
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
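Two details of build_from_cfg are easy to miss. First, default_args only fills keys that the config does not already set (setdefault). Second, an exception raised inside the constructor is re-raised as the same type with the class name prepended. The sketch below reproduces just those two steps, using a plain dict in place of the registry. Head, REGISTRY and build are hypothetical names.

```python
class Head:
    """Hypothetical class to be built from a config."""

    def __init__(self, num_classes, in_channels=256):
        self.num_classes = num_classes
        self.in_channels = in_channels


REGISTRY = {'Head': Head}  # a plain dict standing in for a Registry


def build(cfg, default_args=None):
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)   # values in cfg win over default_args
    obj_cls = REGISTRY[args.pop('type')]
    try:
        return obj_cls(**args)
    except Exception as e:
        # Same trick as above: keep the exception type, prepend the class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')


head = build(dict(type='Head', num_classes=10),
             default_args=dict(num_classes=3, in_channels=512))
print(head.num_classes, head.in_channels)  # 10 512

try:
    build(dict(type='Head'))               # num_classes is missing
except TypeError as err:
    print(err)                             # the message starts with "Head: "
```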

Runner

MMCV/runner

The runner's directory layout shows its main responsibilities: checkpointing, training, validation, optimizers and hooks.

Across mmdet3d, mmdet and mmcv, the training call chain can be summarized as: Mmdet3d/train.py -> mmdet/train_detector() -> mmcv/runner.run() -> mmcv/epoch_based_runner.py/train()

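Implementation

mmcv/epoch_based_runner.py/train()

The train function is where the loop over the dataloader actually runs. It also calls many hooks. These hooks were registered into the runner right after the runner was instantiated (see the hook registration code later in this note). The registration goes through the register_hook method of BaseRunner, the parent class of EpochBasedRunner.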

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

mmcv/epoch_based_runner.py/run_iter()

def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
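run_iter only requires train_step/val_step to return a dict. If that dict contains log_vars, the values go into runner.log_buffer together with num_samples. The sketch below assumes the buffer keeps per-key histories of values and counts, then averages them weighted by sample count. That weighting is one reason to pass num_samples along. ToyModel, ToyLogBuffer and the free-standing run_iter are hypothetical simplifications, not mmcv classes.

```python
class ToyLogBuffer:
    """Hypothetical log buffer: per-key histories of values and sample counts."""

    def __init__(self):
        self.val_history = {}
        self.n_history = {}

    def update(self, log_vars, count=1):
        for key, value in log_vars.items():
            self.val_history.setdefault(key, []).append(value)
            self.n_history.setdefault(key, []).append(count)

    def average(self):
        # Sample-weighted mean: sum(value * count) / sum(count)
        return {
            key: sum(v * n for v, n in zip(vals, self.n_history[key]))
            / sum(self.n_history[key])
            for key, vals in self.val_history.items()
        }


class ToyModel:
    """Hypothetical model whose train_step follows the dict contract."""

    def train_step(self, data_batch, optimizer):
        loss = sum(data_batch) / len(data_batch)   # stand-in for a real loss
        return dict(loss=loss, log_vars=dict(loss=loss),
                    num_samples=len(data_batch))


def run_iter(model, log_buffer, data_batch):
    outputs = model.train_step(data_batch, optimizer=None)
    if not isinstance(outputs, dict):
        raise TypeError('model.train_step() must return a dict')
    if 'log_vars' in outputs:
        log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    return outputs


buf = ToyLogBuffer()
model = ToyModel()
run_iter(model, buf, [1.0, 3.0])              # loss 2.0 over 2 samples
run_iter(model, buf, [5.0, 5.0, 5.0, 5.0])    # loss 5.0 over 4 samples
print(buf.average())  # {'loss': 4.0}, i.e. (2*2 + 5*4) / 6
```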

Hook

Hook base class

mmcv/runner/hooks/hook.py

from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters
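Two behaviours of the base class are worth noting. The train/val-specific methods fall back to the generic before_epoch/after_iter/... methods, so a hook that only overrides after_iter fires in both training and validation. The every_n_* helpers test runner.iter + 1, so they count from 1. The trimmed, self-contained copy below shows both. CountingHook and the SimpleNamespace runner are hypothetical stand-ins.

```python
from types import SimpleNamespace


class Hook:
    """Trimmed copy of a few methods from the base class above."""

    def after_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


class CountingHook(Hook):
    """Hypothetical hook that only overrides the generic after_iter."""

    def __init__(self, interval):
        self.interval = interval
        self.fired_at = []

    def after_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            self.fired_at.append((runner.mode, runner.iter))


hook = CountingHook(interval=3)
runner = SimpleNamespace(mode='train', iter=0)  # hypothetical stand-in runner
for it in range(7):
    runner.iter = it
    hook.after_train_iter(runner)
runner.mode, runner.iter = 'val', 8
hook.after_val_iter(runner)
print(hook.fired_at)  # [('train', 2), ('train', 5), ('val', 8)]
```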

All hooks

__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook'
]
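
  1. Embedding hooks into the runner

    The runner stores the instantiated hook objects in a list (self._hooks) kept sorted by priority, which acts as a priority queue. This order decides the order in which hooks are called. The priority levels are defined as follows (a lower value means a higher priority):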

        """Hook priority levels.
        +------------+------------+
        | Level      | Value      |
        +============+============+
        | HIGHEST    | 0          |
        +------------+------------+
        | VERY_HIGH  | 10         |
        +------------+------------+
        | HIGH       | 30         |
        +------------+------------+
        | NORMAL     | 50         |
        +------------+------------+
        | LOW        | 70         |
        +------------+------------+
        | VERY_LOW   | 90         |
        +------------+------------+
        | LOWEST     | 100        |
        +------------+------------+
        """
    

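    The runner offers two ways to register a hook:

    1. register_hook
    2. register_hook_from_cfg

    Both methods are implemented in BaseRunner, the runner base class in mmcv/runner/base_runner.py. register_hook scans the hook list from the back. At the first hook whose priority is higher than or equal to the new hook's (its priority value is less than or equal), it inserts the new hook right after that hook. If no such hook exists, the new hook goes to the front. As a result, hooks with the same priority are called in the order they were registered.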

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.
        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority.
        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)
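The reverse scan in register_hook keeps the hook list sorted by priority value and preserves registration order among hooks with equal priority. The sketch below reproduces that insertion with a Priority enum matching the table above. get_priority here is a simplified stand-in for the helper of the same name, accepting an int, an enum member or a level name. MiniRunner and NamedHook are hypothetical.

```python
from enum import Enum


class Priority(Enum):
    """Levels from the table above (lower value = called earlier)."""
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    NORMAL = 50
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Simplified stand-in: accept an int, a Priority, or a level name."""
    if isinstance(priority, int):
        if not 0 <= priority <= 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    if isinstance(priority, Priority):
        return priority.value
    if isinstance(priority, str):
        return Priority[priority.upper()].value
    raise TypeError('priority must be an int, str or Priority')


class NamedHook:
    def __init__(self, name):
        self.name = name


class MiniRunner:
    def __init__(self):
        self._hooks = []

    def register_hook(self, hook, priority='NORMAL'):
        priority = get_priority(priority)
        hook.priority = priority
        # Walk backwards and insert after the last hook whose value <= ours,
        # so hooks with equal priority keep their registration order.
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                return
        self._hooks.insert(0, hook)


runner = MiniRunner()
runner.register_hook(NamedHook('logger'), 'VERY_LOW')
runner.register_hook(NamedHook('ckpt'))                # default NORMAL (50)
runner.register_hook(NamedHook('lr'), 'VERY_HIGH')
runner.register_hook(NamedHook('custom'), 50)          # same level as ckpt
runner.register_hook(NamedHook('first'), Priority.HIGHEST)
print([(h.name, h.priority) for h in runner._hooks])
# [('first', 0), ('lr', 10), ('ckpt', 50), ('custom', 50), ('logger', 90)]
```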
    
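  2. Calling hooks from the runner

    call_hook walks the priority-ordered hook list from front to back, so priorities are respected. As the implementation shows, every call to call_hook invokes every hook in the list. Each hook runs its own version of the method named fn_name, and a hook that does not override that method falls back to the no-op default from the Hook base class.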

    def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)
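call_hook simply runs getattr(hook, fn_name)(self) for every hook in order. A hook that does not override a given method is therefore still called, and it just runs the base-class default, which does nothing. The minimal sketch below uses a hypothetical TraceHook and MiniRunner to record which hooks actually did something, and in what order.

```python
class Hook:
    """Two no-op defaults, trimmed from the base class above."""

    def before_train_epoch(self, runner):
        pass

    def after_train_iter(self, runner):
        pass


class TraceHook(Hook):
    """Hypothetical hook that records when after_train_iter runs."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def after_train_iter(self, runner):
        self.log.append(self.name)


class MiniRunner:
    def __init__(self, hooks):
        self._hooks = hooks   # assumed to be already sorted by priority

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)


log = []
runner = MiniRunner([TraceHook('lr', log), Hook(), TraceHook('logger', log)])
runner.call_hook('before_train_epoch')  # every hook is called; all are no-ops
runner.call_hook('after_train_iter')
print(log)  # ['lr', 'logger']
```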
    
  3. Registering hooks before training

    After the runner object is instantiated, train_detector registers the hooks the runner will use:

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
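In the config, custom_hooks is a list of dicts. For each entry the loop above pops priority before handing the rest to build_from_cfg, so the hook's constructor never receives it. The fragment below uses EmptyCacheHook, which appears in the hook list above. Its after_iter=True argument is an assumption about that hook's constructor, so check it against your mmcv version. The loop after the fragment only replays the copy/pop step on the config, without mmcv.

```python
# Config fragment (mmdet-style Python config). EmptyCacheHook is taken from
# the hook list above; the after_iter argument is assumed, check your mmcv.
custom_hooks = [
    dict(type='EmptyCacheHook', after_iter=True, priority='LOW'),
]

# Replaying what the registration loop does with each entry, without mmcv:
for hook_cfg in custom_hooks:
    hook_cfg = hook_cfg.copy()                      # the config itself stays intact
    priority = hook_cfg.pop('priority', 'NORMAL')   # 'LOW' here
    print(hook_cfg, priority)
    # {'type': 'EmptyCacheHook', 'after_iter': True} LOW
```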
    

    mmcv/runner/base_runner.py

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.
        Default and custom hooks include:
          Hooks                 Priority
        - LrUpdaterHook         10
        - MomentumUpdaterHook   30
        - OptimizerStepperHook  50
        - CheckpointSaverHook   70
        - IterTimerHook         80
        - LoggerHook(s)         90
        - CustomHook(s)         50 (default)
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)
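register_hook keeps equal priorities in registration order. The final hook order is therefore the same as a stable sort of the registration sequence by priority value. The sketch below uses the priorities quoted in the docstring above, which may differ in other mmcv versions. Python's stable sorted then gives the order in which these hooks would be called. The names are only labels here.

```python
# (name, priority) pairs in the order register_training_hooks registers them,
# using the values quoted in the docstring above.
registered = [
    ('LrUpdaterHook', 10),
    ('MomentumUpdaterHook', 30),
    ('OptimizerStepperHook', 50),
    ('CheckpointSaverHook', 70),
    ('IterTimerHook', 80),
    ('LoggerHook', 90),
    ('CustomHook', 50),   # custom hooks default to NORMAL (50)
]

# sorted() is stable, so CustomHook stays after OptimizerStepperHook.
call_order = [name for name, _ in sorted(registered, key=lambda item: item[1])]
print(call_order)
```

Because of this ordering, a custom hook left at the default priority runs after the optimizer hook within the same event, such as after_train_iter.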
    
  4. Calling hooks during training and validation

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')
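A stripped-down runner can trace the sequence of hook events that train and val produce for one epoch. In this hypothetical sketch (TinyRunner and RecorderHook are made up) there is no model and no time.sleep, and run_iter does almost nothing. The order of call_hook calls follows the two methods above.

```python
class RecorderHook:
    """Hypothetical hook: answers every before_*/after_* call by logging it."""

    def __init__(self):
        self.events = []

    def __getattr__(self, fn_name):
        # Only reached for names not found normally, i.e. the hook points.
        if fn_name.startswith(('before_', 'after_')):
            return lambda runner: self.events.append(fn_name)
        raise AttributeError(fn_name)


class TinyRunner:
    def __init__(self, hooks):
        self._hooks = hooks
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def run_iter(self, data_batch, train_mode):
        # stand-in for train_step/val_step
        self.outputs = {'num_samples': len(data_batch)}

    def train(self, data_loader):
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True)
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader):
        self.call_hook('before_val_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')


recorder = RecorderHook()
runner = TinyRunner([recorder])
runner.train([[1, 2], [3, 4]])   # two training batches
runner.val([[5]])                # one validation batch
print(runner._iter, runner._epoch)  # 2 1
print(recorder.events)
```

As in the code above, val advances neither _iter nor _epoch; only train does.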
    

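    Example:

    During training, the self.call_hook('after_train_iter') call is where the backward pass and the parameter update should happen. That means OptimizerHook should implement an after_train_iter method that calls .backward() and optimizer.step(), and that is exactly what it does: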

    from torch.nn.utils import clip_grad
    # HOOKS and Hook come from mmcv/runner/hooks/hook.py (shown above)


    @HOOKS.register_module()
    class OptimizerHook(Hook):

        def __init__(self, grad_clip=None):
            self.grad_clip = grad_clip

        def clip_grads(self, params):
            params = list(
                filter(lambda p: p.requires_grad and p.grad is not None, params))
            if len(params) > 0:
                return clip_grad.clip_grad_norm_(params, **self.grad_clip)

        def after_train_iter(self, runner):
            # zero the grads, backprop the loss stored by run_iter,
            # optionally clip, then update the parameters
            runner.optimizer.zero_grad()
            runner.outputs['loss'].backward()
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.optimizer.step()
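Without the runner, OptimizerHook.after_train_iter is the usual PyTorch update step. The sketch below runs that step on a toy linear-regression problem in plain PyTorch. The model, data, learning rate and grad_clip values (max_norm=35, norm_type=2) are illustrative assumptions. outputs plays the role of runner.outputs produced by train_step.

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad

torch.manual_seed(0)
model = nn.Linear(4, 1)                       # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_clip = dict(max_norm=35, norm_type=2)    # illustrative clip settings

inputs = torch.randn(16, 4)
targets = inputs.sum(dim=1, keepdim=True)     # learnable target: y = sum(x)


def train_step(inputs, targets):
    """Stand-in for model.train_step: returns the dict run_iter expects."""
    loss = nn.functional.mse_loss(model(inputs), targets)
    return dict(loss=loss, log_vars=dict(loss=loss.item()),
                num_samples=inputs.size(0))


losses = []
for _ in range(20):
    outputs = train_step(inputs, targets)     # plays the role of runner.outputs
    # The body of OptimizerHook.after_train_iter, without the runner:
    optimizer.zero_grad()
    outputs['loss'].backward()
    params = [p for p in model.parameters()
              if p.requires_grad and p.grad is not None]
    grad_norm = clip_grad.clip_grad_norm_(params, **grad_clip)
    optimizer.step()
    losses.append(outputs['log_vars']['loss'])

print(f'first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}, '
      f'last grad norm {float(grad_norm):.4f}')
```

zero_grad runs before backward so gradients from the previous iteration do not accumulate. Clipping happens between backward and step, which is the only point where the gradients exist and have not yet been applied.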
    
