openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理

Posted 非晚非晚

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理相关的知识,希望对你有一定的参考价值。

文章目录

pointpillar相关的其它文章链接如下:

  1. 【论文阅读】CVPR 2019| PointPillars: 基于点云的快速编码目标检测框架(Fast Encoders for Object Detection from Point Clouds)
  2. OpenPCDet v0.5版本的安装与测试
  3. openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理
  4. openpcdet之pointpillar代码阅读——第二篇:网络结构
  5. openpcdet之pointpillar代码阅读——第三篇:损失函数的计算

1. 数据增强

数据增强部分,相对比较清晰,整体流程如下所示。后续openpcdet也出了一些新的数据增强方法,不过目前本人暂时还没有使用。

数据增强部分代码在:pcdet/datasets/augmentor/data_augmentor.py

1.1 gt数据采集——gt_sampling

该模块思路很简单,就是为了丰富训练数据,也就是将其它帧gt的点云以及box放入待训练帧中的空余位置。下面是这部分的配置文件,官方这部分训练了3种类型。

            - NAME: gt_sampling
              USE_ROAD_PLANE: True
              DB_INFO_PATH:
                  - kitti_dbinfos_train.pkl
              PREPARE: 
                 filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
                 filter_by_difficulty: [-1],
              

              SAMPLE_GROUPS: ['Car:15','Pedestrian:15', 'Cyclist:15']
              NUM_POINT_FEATURES: 4
              DATABASE_WITH_FAKELIDAR: False
              REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
              LIMIT_WHOLE_SCENE: False

首先对采样的gt进行最小点过滤。代码注释如下:

class DataAugmentor(object):
    def __init__(self, root_path, augmentor_configs, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.logger = logger
        
        self.data_augmentor_queue = []
        # 读取数据增强部分配置文件
        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \\
            else augmentor_configs.AUG_CONFIG_LIST
        #逐个读取数据增强部分
        for cur_cfg in aug_config_list:
            if not isinstance(augmentor_configs, list):
            	#不用数据增强的列表DISABLE_AUG_LIST
                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
                    continue
            #使用partial,所以此刻只是把数据增强方法加入队列(data_dict=0)
            # 执行数据增加的函数,并加入至data_augmentor_queue
            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
            self.data_augmentor_queue.append(cur_augmentor)
    
    #gt数据采集部分
    def gt_sampling(self, config=None):
        db_sampler = database_sampler.DataBaseSampler(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger
        )
        return db_sampler

其中DataBaseSampler的代码如下:

class DataBaseSampler(object):
    def __init__(self, root_path, sampler_cfg, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.sampler_cfg = sampler_cfg
        self.logger = logger
        self.db_infos = 
        #按照类别分类
        for class_name in class_names:
            self.db_infos[class_name] = []

        # use_shared_memory = false
        self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
        
        for db_info_path in sampler_cfg.DB_INFO_PATH:
            db_info_path = self.root_path.resolve() / db_info_path
            #按照类别加入数据各自的db数据
            with open(str(db_info_path), 'rb') as f:
                infos = pickle.load(f)
                [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
        #执行最小点过滤和困难点过滤,我这里只用了filter_by_min_points过滤
        for func_name, val in sampler_cfg.PREPARE.items():
            self.db_infos = getattr(self, func_name)(self.db_infos, val)
        
        self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None

        self.sample_groups =  #sample_num、pointer和indices
        self.sample_class_num =  #sample_num
        self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False) #False

        for x in sampler_cfg.SAMPLE_GROUPS:
            class_name, sample_num = x.split(':')
            if class_name not in class_names:
                continue
            self.sample_class_num[class_name] = sample_num
            self.sample_groups[class_name] = 
                'sample_num': sample_num,
                'pointer': len(self.db_infos[class_name]),
                'indices': np.arange(len(self.db_infos[class_name]))
            
 #最小点过滤函数 
  def filter_by_min_points(self, db_infos, min_gt_points_list):
        for name_num in min_gt_points_list:
            #对每个类别单独过滤
            name, min_num = name_num.split(':')
            min_num = int(min_num)
            if min_num > 0 and name in db_infos.keys():
                filtered_infos = []
                for info in db_infos[name]:
                    #box内大于min_num的保留
                    if info['num_points_in_gt'] >= min_num:
                        filtered_infos.append(info)

                if self.logger is not None:
                    self.logger.info('Database filter by min points %s: %d => %d' %
                                     (name, len(db_infos[name]), len(filtered_infos)))
                db_infos[name] = filtered_infos

        return db_infos

1.2 全局翻转——random_world_flip

这部分配置如下,这部分的意义为使points和gt_boxes进行 X轴的全局翻转

        - NAME: random_world_flip
          ALONG_AXIS_LIST: ['x']

官方这部分只做了X轴的对称翻转,翻转概率为50%。

    def random_world_flip(self, data_dict=None, config=None):
    	#data_dict如果为空数据,暂时返回
        if data_dict is None:
            return partial(self.random_world_flip, config=config)
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y']
            #执行具体的翻转函数
            gt_boxes, points = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
                gt_boxes, points,
            )
        
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict
 # x轴对称翻转函数
def random_flip_along_x(gt_boxes, points):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
    Returns:
    """
    # 0.5的概率翻转
    enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
    if enable:
        gt_boxes[:, 1] = -gt_boxes[:, 1] #box,修改x的值
        gt_boxes[:, 6] = -gt_boxes[:, 6] #box,修改heading角
        points[:, 1] = -points[:, 1] #点云翻转
        
        if gt_boxes.shape[1] > 7:
            gt_boxes[:, 8] = -gt_boxes[:, 8]
    
    return gt_boxes, points

1.3 全局旋转——random_world_rotation

这部分配置如下,这部分的意义为使points和gt_boxes进行 绕Z轴的旋转 波动。

        - NAME: random_world_rotation
          WORLD_ROT_ANGLE: [-0.78539816, 0.78539816] #值为弧度

需要说明的是,这里Z轴旋转的角度为弧度。

    def random_world_rotation(self, data_dict=None, config=None):
    	#data_dict如果为空数据,暂时返回
        if data_dict is None:
            return partial(self.random_world_rotation, config=config)
        rot_range = config['WORLD_ROT_ANGLE']
        if not isinstance(rot_range, list):
            rot_range = [-rot_range, rot_range]
        gt_boxes, points = augmentor_utils.global_rotation(
            data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range
        )
# 旋转函数
def global_rotation(gt_boxes, points, rot_range):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C),
        rot_range: [min, max]
    Returns:
    """
    #旋转噪声,从最小至最大中随机取值
    noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
    #点云和box的旋转
    points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
    gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
    gt_boxes[:, 6] += noise_rotation
    if gt_boxes.shape[1] > 7:
        gt_boxes[:, 7:9] = common_utils.rotate_points_along_z(
            np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[np.newaxis, :, :],
            np.array([noise_rotation])
        )[0][:, 0:2]

    return gt_boxes, points
  
# 具体调用的旋转函数为
def rotate_points_along_z(points, angle):
    """
    Args:
        points: (B, N, 3 + C)
        angle: (B), angle along z-axis, angle increases x ==> y
    Returns:

    """
    points, is_numpy = check_numpy_to_torch(points)
    angle, _ = check_numpy_to_torch(angle)

    cosa = torch.cos(angle)
    sina = torch.sin(angle)
    zeros = angle.new_zeros(points.shape[0])
    ones = angle.new_ones(points.shape[0])
    rot_matrix = torch.stack((
        cosa,  sina, zeros,
        -sina, cosa, zeros,
        zeros, zeros, ones
    ), dim=1).view(-1, 3, 3).float()
    points_rot = torch.matmul(points[:, :, 0:3], rot_matrix)
    points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1)
    return points_rot.numpy() if is_numpy else points_rot

1.4 全局尺度变换——random_world_scaling

这部分配置如下,这部分的意义为使points和gt_boxes进行 尺度的缩放

        - NAME: random_world_scaling
          WORLD_SCALE_RANGE: [0.95, 1.05]
    def random_world_scaling(self, data_dict=None, config=None):
    	#data_dict如果为空数据,暂时返回
        if data_dict is None:
            return partial(self.random_world_scaling, config=config)
         #调用尺度函数
        gt_boxes, points = augmentor_utils.global_scaling(
            data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE']
        )
        
        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

#尺度函数
def global_scaling(gt_boxes, points, scale_range):
    """
    Args:
        gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
        points: (M, 3 + C),
        scale_range: [min, max]
    Returns:
    """
    #变换尺度太小,则不用了
    if scale_range[1] - scale_range[0] < 1e-3:
        return gt_boxes, points
    #尺度因子
    noise_scale = np.random.uniform(scale_range[0], scale_range[1])
    points[:, :3] *= noise_scale
    gt_boxes[:, :6] *= noise_scale

    return gt_boxes, points

2. 数据处理

数据处理有3个函数,分别为限制范围、随机打乱点云范围和点云至voxels(或者pillars)的变换。这部分的流程如下:

这部分的代码位于:pcdet/datasets/processor/data_processor.py

class DataProcessor(object):
    def __init__(self, processor_configs, point_cloud_range, training, num_point_features):
        self.point_cloud_range = point_cloud_range #数据范围
        self.training = training
         #给定每个点云的特征维度,这里是x,y,z,r 其中r是激光雷达反射强度
        self.num_point_features = num_point_features
        self.mode = 'train' if training else 'test'
        #grid或voxel或pillar的size
        self.grid_size = self.voxel_size = None
        self.data_processor_queue = []

        self.voxel_generator = None
		#依次加入三个数据处理
        for cur_cfg in processor_configs:
            cur_processor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
            self.data_processor_queue.append(cur_processor)

2.1 数据范围限制

点云的mask和box的mask是分别制作的。

    def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):
        if data_dict is None:
            return partial(self.mask_points_and_boxes_outside_range, config=config)

        if data_dict.get('points', None) is not None:
        	#根据配置文件的范围,做一个点云mask
            mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
            data_dict['points'] = data_dict['points'][mask]

        if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
        	#根据配置文件的范围,做一个gt_box的mask
            mask = box_utils.mask_boxes_outside_range_numpy(
                data_dict['gt_boxes'], self.point_cloud_range, min_num_corners=config.get('min_num_corners', 1)
            )
            data_dict['gt_boxes'] = data_dict['gt_boxes'][mask]
        return data_dict

#制作点云的mask的函数
def mask_points_by_range(points, limit_range):
    mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \\
           & (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4])
    return mask
# gt_box的mask函数
def mask_boxes_outside_range_numpy(boxes, limit_range, min_num_corners=1):
    """
    Args:
        boxes: (N, 7) [x, y, z, dx, dy, dz, heading, ...], (x, y, z) is the box center
        limit_range: [minx, miny, minz, maxx, maxy, maxz]
        min_num_corners:

    Returns:

    """
    if boxes.shape[1] > 7:
        boxes = boxes[:, 0:7]
    corners = boxes_to_corners_3d(boxes)  # (N, 8, 3)
    mask = ((corners >= limit_range[0:3]) & (corners <= limit_range[3:6openpcdet之pointpillar代码阅读——第二篇:网络结构

openpcdet之pointpillar代码阅读——第三篇:损失函数的计算

openpcdet之pointpillar代码阅读——第三篇:损失函数的计算

openpcdet之pointpillar代码阅读——第三篇:损失函数的计算

openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理

openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理