使用自己的数据集,测试mmrotate新网络rotated_rtmdet,旋转目标检测
Posted 东东就是我
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用自己的数据集,测试mmrotate新网络rotated_rtmdet,旋转目标检测相关的知识,希望对你有一定的参考价值。
1.安装
!!!!一定不要直接安装 mmrotate 的发行包(新网络只在 1.x 分支的源码中,见下文)
1.版本需求
mmcv 2.0.0rc2
mmdet 3.0.0rc5
mmengine 0.5.0
不用安装mmcv-full
- 下载mmrotate 1.x 源码 (不要下载默认的master,因为新的网络只在1.x版本中)
2.制作数据集
因为需要的是dota格式的数据集
DOTA 格式的注解 txt 文件:
184 2875 193 2923 146 2932 137 2885 plane 0
66 2095 75 2142 21 2154 11 2107 plane 0
...
每行代表一个对象,并将其记录为一个 10 维数组 A。
A[0:8]: 多边形的格式 (x1, y1, x2, y2, x3, y3, x4, y4)。
A[8]: 类别。
A[9]: 困难程度(difficulty)。
我的数据是用labelme标注的json文件,所以先把json转txt
import json
import os


def shape_to_dota_line(shape, label='box', difficulty=0):
    """Convert one labelme shape into a single DOTA annotation line.

    Args:
        shape (dict): A labelme shape; only ``shape['points']`` is read and
            it must be a sequence of ``[x, y]`` pairs.
        label (str): Class name written as the 9th field.
        difficulty (int): Difficulty flag written as the 10th field.

    Returns:
        str | None: ``'x1 y1 x2 y2 x3 y3 x4 y4 <label> <difficulty>\n'``, or
        ``None`` when the shape is not a 4-point polygon.
    """
    coords = [str(value) for point in shape['points'] for value in point]
    # DOTA needs exactly four corner points (8 numbers); anything else
    # (2-point rectangles, 5-point polygons, ...) cannot be represented.
    if len(coords) != 8:
        return None
    return ' '.join(coords) + f' {label} {difficulty}\n'


def convert_labelme_dir(json_dir, out_dir='.'):
    """Write one DOTA txt file per labelme JSON file found in ``json_dir``.

    Args:
        json_dir (str): Directory containing the labelme ``*.json`` files.
        out_dir (str): Directory the ``*.txt`` files are written to.

    Returns:
        list[str]: Paths of the txt files that were written.
    """
    written = []
    for name in sorted(os.listdir(json_dir)):
        if not name.endswith('.json'):
            continue
        with open(os.path.join(json_dir, name), encoding='utf-8') as f:
            data = json.load(f)
        lines = [shape_to_dota_line(shape) for shape in data['shapes']]
        # splitext keeps dots inside the stem, so 'a.b.json' -> 'a.b.txt'.
        txt_path = os.path.join(out_dir, os.path.splitext(name)[0] + '.txt')
        # The context manager guarantees the file is flushed and closed.
        with open(txt_path, 'w', encoding='utf-8') as out:
            out.writelines(line for line in lines if line is not None)
        written.append(txt_path)
    return written


if __name__ == '__main__':
    convert_labelme_dir('hailuo')
把所有的txt文件放在annfiles文件夹,所有的图片放在images文件夹
3. 修改配置文件
- G:\research\mmrotate-1.x\configs\rotated_rtmdet\_base_\dota_rr.py
# Dataset settings for a custom dataset stored in DOTA format.
dataset_type = 'DOTADataset'
data_root = '../data/'  # dataset root directory
file_client_args = dict(backend='disk')

# Training pipeline: load the image and its quadrilateral boxes, convert the
# boxes to 'rbox', then apply random flip / rotation before packing.
train_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.RandomFlip',
        prob=0.75,
        direction=['horizontal', 'vertical', 'diagonal'],
    ),
    dict(
        type='RandomRotate',
        prob=0.5,
        angle_range=180,
        rect_obj_labels=[9, 11],
    ),
    dict(type='mmdet.PackDetInputs'),
]

# Validation pipeline: no augmentation; annotations are kept for evaluation.
val_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    # avoid bboxes being resized
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.PackDetInputs',
        # 'scale_factor' was removed from meta_keys
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'),
    ),
]

# Test pipeline: image only, no annotations are loaded.
test_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='mmdet.PackDetInputs',
        # 'scale_factor' was removed from meta_keys
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'),
    ),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    pin_memory=False,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='trainval/annfiles/',  # folder with the txt annotations
        data_prefix=dict(img_path='trainval/images/'),  # folder with images
        img_shape=(512, 640),
        filter_cfg=dict(filter_empty_gt=True),
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='trainval/annfiles/',  # folder with the txt annotations
        data_prefix=dict(img_path='trainval/images/'),  # folder with images
        img_shape=(512, 640),
        test_mode=True,
        pipeline=val_pipeline,
    ),
)
# Testing reuses the validation loader object as-is.
test_dataloader = val_dataloader

val_evaluator = dict(type='DOTAMetric', metric='mAP')
test_evaluator = val_evaluator
- G:\research\mmrotate-1.x\mmrotate\datasets\dota.py
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
from typing import List, Tuple
from mmengine.dataset import BaseDataset
from mmrotate.registry import DATASETS
@DATASETS.register_module()
class DOTADataset(BaseDataset):
    """DOTA-format dataset for detection (customized to one class, ``box``).

    Note: ``ann_file`` in DOTADataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTADataset,
    it is the path of a folder containing DOTA-format txt files, one per
    image.

    Args:
        img_shape (tuple[int]): ``(height, width)`` recorded for every image.
            Due to the huge size of the remote sensing image, we will cut it
            into slices with the same shape. Defaults to (1024, 1024).
        diff_thr (int): The difficulty threshold of ground truth. Bboxes
            with difficulty higher than it will be ignored. The range of this
            value should be non-negative integer. Defaults to 100.
        img_suffix (str): Image file extension, with or without the leading
            dot. Defaults to 'png'.
    """

    METAINFO = {
        # Modified: the custom dataset has a single class.
        'classes': ('box', ),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), ],
    }

    def __init__(self,
                 img_shape: Tuple[int, int] = (1024, 1024),
                 diff_thr: int = 100,
                 img_suffix: str = 'png',
                 **kwargs) -> None:
        # Attributes are assigned before calling the parent constructor, as in
        # the original code — presumably because the parent may already call
        # ``load_data_list`` during init (TODO confirm against mmengine).
        self.img_shape = img_shape
        self.diff_thr = diff_thr
        self.img_suffix = img_suffix.lstrip('.')
        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the folder named by ``self.ann_file``.

        When ``self.ann_file`` is an empty string, every image found under
        ``self.data_prefix['img_path']`` is listed with a single placeholder
        instance instead.

        Returns:
            List[dict]: A list of annotation.

        Raises:
            ValueError: If ``self.ann_file`` is set but holds no txt file.
        """  # noqa: E501
        # in mmdet v2.0 label is 0-based
        cls_map = {c: i for i, c in enumerate(self.metainfo['classes'])}
        height, width = self.img_shape[0], self.img_shape[1]
        data_list = []

        if self.ann_file == '':
            img_files = glob.glob(
                osp.join(self.data_prefix['img_path'],
                         f'*.{self.img_suffix}'))
            for img_path in img_files:
                img_name = osp.basename(img_path)
                placeholder = dict(bbox=[], bbox_label=[], ignore_flag=0)
                data_list.append({
                    'img_path': img_path,
                    'file_name': img_name,
                    'img_id': osp.splitext(img_name)[0],
                    'height': height,
                    'width': width,
                    'instances': [placeholder],
                })
            return data_list

        txt_files = glob.glob(osp.join(self.ann_file, '*.txt'))
        if len(txt_files) == 0:
            raise ValueError('There is no txt file in '
                             f'{self.ann_file}')
        for txt_file in txt_files:
            img_id = osp.splitext(osp.basename(txt_file))[0]
            img_name = f'{img_id}.{self.img_suffix}'
            data_list.append({
                'img_id': img_id,
                'file_name': img_name,
                'img_path': osp.join(self.data_prefix['img_path'], img_name),
                'height': height,
                'width': width,
                'instances': self._parse_instances(txt_file, cls_map),
            })
        return data_list

    def _parse_instances(self, txt_file: str, cls_map: dict) -> List[dict]:
        """Parse one DOTA txt file into a list of instance dicts."""
        instances = []
        with open(txt_file) as f:
            for line in f:
                bbox_info = line.split()
                if not bbox_info:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                cls_name = bbox_info[8]
                # Skip objects whose class is not part of this dataset. A
                # name-length check would also drop valid long class names.
                if cls_name not in cls_map:
                    continue
                difficulty = int(bbox_info[9])
                instances.append({
                    'bbox': [float(v) for v in bbox_info[:8]],
                    'bbox_label': cls_map[cls_name],
                    'ignore_flag': 1 if difficulty > self.diff_thr else 0,
                })
        return instances

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode nothing is filtered. Otherwise images without any
        instance are dropped when ``filter_cfg['filter_empty_gt']`` is true.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_empty_gt = (
            self.filter_cfg.get('filter_empty_gt', False)
            if self.filter_cfg is not None else False)

        return [
            data_info for data_info in self.data_list
            if not (filter_empty_gt and len(data_info['instances']) == 0)
        ]

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get DOTA category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index.
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]
@DATASETS.register_module()
class DOTAv15Dataset(DOTADataset):
    """DOTA-v1.5 dataset for detection.

    Note: ``ann_file`` in DOTAv15Dataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTAv15Dataset,
    it is the path of a folder containing txt annotation files.
    """

    METAINFO = {
        'classes':
        ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
         'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
         'basketball-court', 'storage-tank', 'soccer-ball-field',
         'roundabout', 'harbor', 'swimming-pool', 'helicopter',
         'container-crane'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
                    (138, 43, 226), (255, 128, 0), (255, 0, 255),
                    (0, 255, 255), (255, 193, 193), (0, 51, 153),
                    (255, 250, 205), (0, 139, 139), (255, 255, 0),
                    (147, 116, 116), (0, 0, 255), (220, 20, 60)],
    }
@DATASETS.register_module()
class DOTAv2Dataset(DOTADataset):
    """DOTA-v2.0 dataset for detection.

    Note: ``ann_file`` in DOTAv2Dataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTAv2Dataset,
    it is the path of a folder containing txt annotation files.
    """

    METAINFO = {
        'classes':
        ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
         'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
         'basketball-court', 'storage-tank', 'soccer-ball-field',
         'roundabout', 'harbor', 'swimming-pool', 'helicopter',
         'container-crane', 'airport', 'helipad'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
                    (138, 43, 226), (255, 128, 0), (255, 0, 255),
                    (0, 255, 255), (255, 193, 193), (0, 51, 153),
                    (255, 250, 205), (0, 139, 139), (255, 255, 0),
                    (147, 116, 116), (0, 0, 255), (220, 20, 60),
                    (119, 11, 32), (0, 0, 142)],
    }
- G:\research\mmrotate-1.x\configs\rotated_rtmdet\rotated_rtmdet_l-3x-dota_ms.py
修改:将模型配置中的 num_classes 改为 1(本数据集只有 box 一个类别)
# Config files this config builds on (paths relative to this file).
_base_ = [
    './_base_/default_runtime.py',
    './_base_/schedule_3x.py',
    './_base_/dota_rr_ms.py',
]

# URL of the pretrained weights file.
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-l_8xb256-rsb-a1-600e_in1k-6a760974.pth'  # noqa
# Name of the rotated-box angle convention.
angle_version = 'le90'
model = dict(
type='mmdet.RTMDet',
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False,
boxtype2tensor=False,
batch_augments=None),
backbone=dict(
type='mmdet.CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=1,
widen_factor=1,
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU')
华为新网络系列 | 什么是SDN?
[深度学习][原创]旋转目标检测框架yolov5_obb,paddledetection-s2anet和mmrotate谁最好用?