mmdet3d training 流程
Posted ZLTJohn
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了mmdet3d training 流程相关的知识,希望对你有一定的参考价值。
一般大家的pytorch训练代码都比较简洁,mmdet3d为了支持扩展性,把代码进行了很多的抽象和封装,大大降低了可读性。现在简单理一下其training的代码执行逻辑。
实际使用的时候肯定是train几个epoch之后eval一次的,这里只考虑training
训练前配置
配置一下data的config
dataset_type = 'CustomWaymoDataset'
data_root = '/localdata_ssd/waymo_ssd_train_only/kitti_format/'
data = dict(
samples_per_gpu=1,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
num_views=num_views,
ann_file=data_root + 'waymo_infos_train.pkl',
split='training',
pipeline=train_pipeline,
modality=input_modality,
classes=class_names,
test_mode=False,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5)),
训练
因为大家共享GPU,难免一台机器有其他人在使用部分gpu,因此需要export CUDA_VISIBLE_DEVICES=1,2,3
为dist_train.sh指定配置文件和gpu数量,就能开始训练了
bash tools/dist_train.sh projects/configs/detr3d/detr3d_res101_gridmask_waymo.py 8
tools/dist_train.sh
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \\
$(dirname "$0")/train.py $CONFIG --launcher pytorch $@:3
分布式运行tools/train.py
tools/train.py
def parse_args()就是传递命令行参数,注意到这里很多参数和config.py里设置的是重复的,习惯直接把参数设置在config.py里。
custom plugin import
如果我们自己写了自定义的模型、dataset等代码,要以plugin的形式嵌入mmdet3d,那么在这里要先import进来。我的理解,import module的时候,python会把整个module的代码跑一遍:首先找到__init__.py,我们提前在里面写好了要import哪些submodule,然后python继续import submodule,比如说,detr3d的detector:
from mmdet.models import DETECTORS
@DETECTORS.register_module()
class Detr3D(MVXTwoStageDetector):
"""Detr3D."""
...
这个时候,因为有@DETECTORS.register_module()的存在,detr3d就会被mmlab里的registry注册到detectors这个module底下,这样对于mmdet的所有代码来说,detr3d这个类都是可访问的了。之后用mmdet3d.models.build_model()方法,他也能找到这个类,实例化一个detr3d出来。不过需要注意的是,如果把customdataset注册到mmdet里,mmdet3d的dataset builder会找不到我写的dataset类,必须放到mmdet3d里才行。我还没找出原因。还需要进一步理解python的机制。
initialize
之后就会根据config和args做初始化的工作,比如初始化logger,dataset,model等。最后调用mmdet3d.apis.train_model(),里面进一步调用下面的mmdet.apis.train_detector
mmdetection\\mmdet\\apis\\train.py:train_detector()
进一步初始化,初始化dataloader,optimizer,runner等。最后调用runner.run(data_loaders, cfg.workflow)
,进入真正的训练。
runner用于管理整个training procedure,具体原理见mmdet官方教程(我看的知乎)。
mmcv\\mmcv\\runner\\epoch_based_runner.py
runner挂了很多hook,用于定义训练不同阶段的行为,比如训练前后要保存什么信息,每个epoch或者iter前后要做些什么,比如learning rate调整和根据gradient来optimize weights。通常来说每train几个epoch就要eval一次,我们配置参数之后,runner也会帮你做。
runner.train()
定义了一个epoch会做哪些事,epoch前后会call对应的hook,iter前后也会。
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
runner.run_iter()
核心就是
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
这样,data就最终进入model的train_step了,比如fcos3d,就会进入继承自父类mmdet.models.detectors.BaseDetector的train_step,进而进行forward,loss等操作,注意这里传进去的optimizer,对于fcos3d来说是没用的,可能将来或者其他地方会用到。
以上是关于mmdet3d training 流程的主要内容,如果未能解决你的问题,请参考以下文章
mmdet3d+waymo test/evaluation流程
mmdet3d+waymo test/evaluation流程