深度学习和目标检测系列教程 10-300:通过torch训练第一个Faster-RCNN模型
Posted 刘润森!
代码的灵感来自此处的 Pytorch 文档教程和Kaggle
- https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
- https://www.kaggle.com/yerramvarun/fine-tuning-faster-rcnn-using-pytorch/
git clone https://github.com/pytorch/vision.git
cp vision/references/detection/utils.py ./
cp vision/references/detection/transforms.py ./
cp vision/references/detection/coco_eval.py ./
cp vision/references/detection/engine.py ./
cp vision/references/detection/coco_utils.py ./
import os
import numpy as np
import cv2
import torch
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from xml.etree import ElementTree as et
from torchvision import transforms as torchtrans
class FruitImagesDataset(torch.utils.data.Dataset):
    """Detection dataset read from a folder of ``.jpg`` images and ``.xml`` annotations.

    Every ``<name>.jpg`` in ``files_dir`` is paired with ``<name>.xml`` in the
    same folder. The XML is expected to hold ``<object>`` entries, each with a
    ``<bndbox>`` giving ``xmin``/``ymin``/``xmax``/``ymax`` in pixels of the
    original image.

    Args:
        files_dir: Folder containing the images and their annotation files.
        width: Width every image is resized to.
        height: Height every image is resized to.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a
            mapping with ``'image'`` and ``'bboxes'`` entries.
    """

    def __init__(self, files_dir, width, height, transforms=None):
        self.transforms = transforms
        self.files_dir = files_dir
        self.height = height
        self.width = width
        # Sorted so that a given index always refers to the same image.
        self.imgs = [image for image in sorted(os.listdir(files_dir))
                     if image[-4:] == '.jpg']
        # Index 0 is reserved for the background class; a bare `_` here only
        # works inside a notebook session and is a NameError in a script.
        self.classes = ['__background__', 'apple', 'banana', 'orange']

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image at position ``idx``.

        ``target`` is a dict with ``boxes`` (N x 4 float32, rescaled to the
        resized image), ``labels`` (N int64), ``area`` (N), ``iscrowd``
        (N int64 zeros) and ``image_id``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_name = self.imgs[idx]
        image_path = os.path.join(self.files_dir, img_name)

        # Read the image, convert BGR -> RGB, resize, and scale to [0, 1].
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('could not read image: ' + image_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        # `interpolation` must be a keyword: the third positional parameter
        # of cv2.resize is `dst`, not the interpolation flag.
        img_res = cv2.resize(img_rgb, (self.width, self.height),
                             interpolation=cv2.INTER_AREA)
        img_res /= 255.0

        # Annotation file sits next to the image with an .xml extension.
        annot_filename = img_name[:-4] + '.xml'
        annot_file_path = os.path.join(self.files_dir, annot_filename)

        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()

        # cv2 images are shaped (height, width, channels).
        wt = img.shape[1]
        ht = img.shape[0]

        # Rescale every annotated box from the original size to the resized one.
        for member in root.findall('object'):
            # Assumes each <object> has a <name> child that matches an entry
            # of self.classes -- TODO confirm against the annotation files.
            labels.append(self.classes.index(member.find('name').text))

            bndbox = member.find('bndbox')
            xmin = int(bndbox.find('xmin').text)
            xmax = int(bndbox.find('xmax').text)
            ymin = int(bndbox.find('ymin').text)
            ymax = int(bndbox.find('ymax').text)

            xmin_corr = (xmin / wt) * self.width
            xmax_corr = (xmax / wt) * self.width
            ymin_corr = (ymin / ht) * self.height
            ymax_corr = (ymax / ht) * self.height
            boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])

        # reshape keeps the tensor N x 4 even when the image has no objects,
        # so the column indexing below stays valid.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Treat every instance as a single object (no crowd regions).
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        target["image_id"] = torch.tensor([idx])

        if self.transforms:
            # Boxes and labels go through the pipeline with the image so the
            # boxes stay aligned with whatever geometric change is applied.
            sample = self.transforms(image=img_res,
                                     bboxes=target['boxes'],
                                     labels=labels)
            img_res = sample['image']
            target['boxes'] = torch.as_tensor(
                sample['bboxes'], dtype=torch.float32).reshape(-1, 4)

        return img_res, target

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``files_dir``."""
        return len(self.imgs)
def torch_to_pil(img):
    """Run ``img`` through torchvision's ``ToPILImage`` and return it in RGB mode."""
    to_pil = torchtrans.ToPILImage()
    pil_img = to_pil(img)
    return pil_img.convert('RGB')
def plot_img_bbox(img, target):
    """Show ``img`` with the boxes in ``target['boxes']`` drawn on top.

    Args:
        img: Anything ``matplotlib`` can ``imshow`` (PIL image or H x W x 3 array).
        target: Mapping whose ``'boxes'`` entry is an iterable of
            ``[xmin, ymin, xmax, ymax]`` boxes in pixel coordinates.
    """
    fig, a = plt.subplots(1, 1)
    fig.set_size_inches(5, 5)
    a.imshow(img)
    for box in target['boxes']:
        # Plain floats keep matplotlib independent of the tensor's device.
        xmin, ymin, xmax, ymax = (float(v) for v in box)
        # Rectangle takes the top-left corner plus width and height.
        rect = patches.Rectangle((xmin, ymin),
                                 xmax - xmin, ymax - ymin,
                                 linewidth=2,
                                 edgecolor='r',
                                 facecolor='none')
        a.add_patch(rect)
    plt.show()
def get_transform(train):
    """Build the albumentations pipeline for training or evaluation.

    Args:
        train: When true, a random horizontal flip is applied before the
            tensor conversion; otherwise only the conversion is applied.

    Returns:
        An ``A.Compose`` whose boxes are in ``pascal_voc`` format and whose
        ``labels`` field is carried along with the boxes.
    """
    bbox_params = {'format': 'pascal_voc', 'label_fields': ['labels']}
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            # ToTensorV2 turns the H x W x C array into a C x H x W tensor
            # without rescaling; the dataset already divides by 255.
            ToTensorV2(p=1.0),
        ], bbox_params=bbox_params)
    return A.Compose([
        ToTensorV2(p=1.0),
    ], bbox_params=bbox_params)
# Locations of the training and test images.
files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
test_dir = '../input/fruit-images-for-object-detection/test_zip/test'

# Sanity check on one raw sample. Only files_dir and test_dir are defined,
# so the dataset is built from files_dir.
dataset = FruitImagesDataset(files_dir, 480, 480)
img, target = dataset[78]
print(img.shape, '\n', target)
# Without transforms the dataset yields an H x W x 3 float array in [0, 1],
# which imshow can draw directly; no tensor-to-PIL conversion is needed.
plot_img_bbox(img, target)
模型导入:使用 torchvision 提供的 Faster R-CNN,并通过 `from torchvision.models.detection.faster_rcnn import FastRCNNPredictor` 导入用于替换检测头的 FastRCNNPredictor。
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def get_object_detection_model(num_classes):
    """Return a pre-trained Faster R-CNN whose box predictor outputs ``num_classes`` classes."""
    # Load the model pre-trained on COCO (the matching weights are downloaded).
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    heads = detector.roi_heads
    # The new head must accept the same number of input features as the
    # classifier it replaces.
    feature_dim = heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a fresh one.
    heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector
对象检测的增强与正常增强不同,因为在这里需要确保 bbox 在转换后仍然正确与对象对齐。
def get_transform(train):
    """Build the albumentations pipeline for training or evaluation.

    Args:
        train: When true, a random horizontal flip is applied before the
            tensor conversion; otherwise only the conversion is applied.

    Returns:
        An ``A.Compose`` whose boxes are in ``pascal_voc`` format and whose
        ``labels`` field is carried along with the boxes.
    """
    bbox_params = {'format': 'pascal_voc', 'label_fields': ['labels']}
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            # ToTensorV2 turns the H x W x C array into a C x H x W tensor
            # without rescaling; the dataset already divides by 255.
            ToTensorV2(p=1.0),
        ], bbox_params=bbox_params)
    return A.Compose([
        ToTensorV2(p=1.0),
    ], bbox_params=bbox_params)
def collate_fn(batch):
    """Regroup ``[(img, target), ...]`` into ``(imgs, targets)`` tuples.

    Targets hold a different number of boxes per image, so they cannot be
    stacked into one tensor the way the default collate function would try to.
    """
    return tuple(zip(*batch))


# Two views over the same folder: one with training augmentation, one without.
dataset = FruitImagesDataset(files_dir, 480, 480, transforms=get_transform(train=True))
dataset_test = FruitImagesDataset(files_dir, 480, 480, transforms=get_transform(train=False))

# Shuffle once and use the same permutation for both views, so the train and
# test subsets never share an image.
indices = torch.randperm(len(dataset)).tolist()

# Hold out the last 20% of the shuffled indices for testing.
test_split = 0.2
tsize = int(len(dataset) * test_split)
# An explicit split point stays correct when tsize is 0; indices[:-0] would
# be empty and indices[-0:] would be the whole list.
split_at = len(indices) - tsize
dataset = torch.utils.data.Subset(dataset, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

# Training and validation data loaders.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=4,
    collate_fn=collate_fn)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=10, shuffle=False, num_workers=4,
    collate_fn=collate_fn)
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Size of the model's classification output.
num_classes = 4

# Build the model with the helper and move it to the chosen device.
model = get_object_detection_model(num_classes)
model.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# Learning-rate scheduler that decreases the learning rate by 10x every
# 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)
# train_one_epoch and evaluate come from engine.py, one of the torchvision
# detection reference files copied next to this script during setup.
from engine import evaluate, train_one_epoch

# Train for 10 epochs.
num_epochs = 10
for epoch in range(num_epochs):
    # One pass over the training data, logging every 10 iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # Advance the schedule once per epoch, after the optimizer updates.
    lr_scheduler.step()
    # Evaluate on the held-out split.
    evaluate(model, data_loader_test, device=device)
Torchvision 为我们提供了一个对预测结果应用 NMS(非极大值抑制)的实用函数,让我们用它来构建一个 apply_nms 函数。
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Return a copy of ``orig_prediction`` filtered by non-maximum suppression.

    Args:
        orig_prediction: Mapping with ``'boxes'``, ``'scores'`` and
            ``'labels'`` tensors, as indexed below.
        iou_thresh: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        A new dict in which ``'boxes'``, ``'scores'`` and ``'labels'`` hold
        only the kept detections. Any other entries are carried over as-is,
        and ``orig_prediction`` itself is left untouched.
    """
    # torchvision returns the indices of the boxes to keep.
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Shallow copy: assigning into the caller's dict would overwrite the raw
    # prediction the caller may still want to inspect or plot.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
# Convert a torch tensor back to a PIL image.
def torch_to_pil(img):
    """Run ``img`` through torchvision's ``ToPILImage`` and return it in RGB mode."""
    to_pil = torchtrans.ToPILImage()
    pil_img = to_pil(img)
    return pil_img.convert('RGB')
# Pick one image from the test set.
img, target = dataset_test[5]

# Put the model in evaluation mode so it returns detections instead of
# training losses.
model.eval()
with torch.no_grad():
    # The model takes a list of images; keep the result for the only one.
    prediction = model([img.to(device)])[0]

print('predicted #boxes: ', len(prediction['labels']))
print('real #boxes: ', len(target['labels']))

# Ground truth first, then the raw prediction.
plot_img_bbox(torch_to_pil(img), target)
plot_img_bbox(torch_to_pil(img), prediction)
你可以看到我们的模型为每个苹果预测了很多边界框。让我们对其应用 nms 并查看最终输出
# Filter the prediction with NMS at an IoU threshold of 0.2 (apply_nms
# defaults to 0.3) and plot the boxes that remain.
nms_prediction = apply_nms(prediction, iou_thresh=0.2)
plot_img_bbox(torch_to_pil(img), nms_prediction)
后续值得继续探索的方向:如何微调 Faster R-CNN 模型及其 ResNet-50 主干网络;如何更改训练配置,比如图像大小、优化器和学习率;以及如何更好地使用 Albumentations。
