mmdet3d training 流程

Posted ZLTJohn

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了mmdet3d training 流程相关的知识,希望对你有一定的参考价值。

一般大家的pytorch训练代码都比较简洁,mmdet3d为了支持扩展性,把代码进行了很多的抽象和封装,大大降低了可读性。现在简单理一下其training的代码执行逻辑。
实际使用的时候肯定是train几个epoch之后eval一次的,这里只考虑training

训练前配置

配置一下data的config

dataset_type = 'CustomWaymoDataset'
data_root = '/localdata_ssd/waymo_ssd_train_only/kitti_format/' 
data = dict(
    samples_per_gpu=1,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=1,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            num_views=num_views,
            ann_file=data_root + 'waymo_infos_train.pkl',
            split='training',
            pipeline=train_pipeline,
            modality=input_modality,
            classes=class_names,
            test_mode=False,
            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
            box_type_3d='LiDAR',
            # load one frame every five frames
            load_interval=5)),

训练

因为大家共享GPU,难免一台机器有其他人在使用部分gpu,因此需要export CUDA_VISIBLE_DEVICES=1,2,3
为dist_train.sh指定配置文件和gpu数量,就能开始训练了
bash tools/dist_train.sh projects/configs/detr3d/detr3d_res101_gridmask_waymo.py 8

tools/dist_train.sh

python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\
    $(dirname "$0")/train.py $CONFIG --launcher pytorch $@:3

分布式运行tools/train.py

tools/train.py

def parse_args()就是传递命令行参数,注意到这里很多参数和config.py里设置的是重复的,习惯直接把参数设置在config.py里。

custom plugin import

如果我们自己写了自定义的模型、dataset等代码,要以plugin的形式嵌入mmdet3d,那么在这里要先import进来。我的理解,import module的时候,python会把整个module的代码跑一遍:首先找到__init__.py,我们提前在里面写好了要import哪些submodule,然后python继续import submodule,比如说,detr3d的detector:

from mmdet.models import DETECTORS
@DETECTORS.register_module()
class Detr3D(MVXTwoStageDetector):
    """Detr3D."""
    ...

这个时候,因为有@DETECTORS.register_module()的存在,detr3d就会被mmlab里的registry注册到detectors这个module底下,这样对于mmdet的所有代码来说,detr3d这个类都是可访问的了。之后用mmdet3d.models.build_model()方法,他也能找到这个类,实例化一个detr3d出来。不过需要注意的是,如果把customdataset注册到mmdet里,mmdet3d的dataset builder会找不到我写的dataset类,必须放到mmdet3d里才行。我还没找出原因。还需要进一步理解python的机制。

initialize

之后就会根据config和args做初始化的工作,比如初始化logger,dataset,model等。最后调用mmdet3d.apis.train_model(),里面进一步调用下面的mmdet.apis.train_detector

mmdetection\\mmdet\\apis\\train.py:train_detector()

进一步初始化,初始化dataloader,optimizer,runner等。最后调用runner.run(data_loaders, cfg.workflow),进入真正的训练。
runner用于管理整个training procedure,具体原理见mmdet官方教程(我看的知乎)。

mmcv\\mmcv\\runner\\epoch_based_runner.py

runner挂了很多hook,用于定义训练不同阶段的行为,比如训练前后要保存什么信息,每个epoch或者iter前后要做些什么,比如learning rate调整和根据gradient来optimize weights。通常来说每train几个epoch就要eval一次,我们配置参数之后,runner也会帮你做。

runner.train()

定义了一个epoch会做哪些事,epoch前后会call对应的hook,iter前后也会。

 self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

runner.run_iter()

核心就是

outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)

这样,data就最终进入model的train_step了,比如fcos3d,就会进入继承自父类mmdet.models.detectors.BaseDetector的train_step,进而进行forward,loss等操作,注意这里传进去的optimizer,对于fcos3d来说是没用的,可能将来或者其他地方会用到。

以上是关于mmdet3d training 流程的主要内容,如果未能解决你的问题,请参考以下文章

mmdet3d training 流程

mmdet3d+waymo test/evaluation流程

mmdet3d+waymo test/evaluation流程

mmdet3d+waymo test/evaluation流程

mmdet3d+waymo 踩坑+验证环境正确性流程

mmdet3d+waymo 踩坑+验证环境正确性流程