深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强

Posted 刘润森!

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强相关的知识,希望对你有一定的参考价值。

@Author:Runsen

上次对xml文件进行提取,使用到一个Albumentation模块。Albumentation模块是一个数据增强的工具,目标检测图像预处理通过使用“albumentation”来应用的,这是一个易于与PyTorch数据转换集成的python库。

Albumentation 是一种工具,可以在将(图像/图片)插入模型之前自定义 处理(弹性、网格、运动模糊、移位、缩放、旋转、转置、对比度、亮度等])到图像/图片。

对此,Albumentation 官方文档:

  • https://albumentations.ai/

为什么要看看这个东西?因为将 Torchvision 代码重构为 Albumentation 的效果最好,运行更快。

上图是使用 Intel Xeon Platinum 8168 CPUImageNet中通过 2000 个验证集图像的测试结果。每个单元格中的值表示在单个核心中处理的图像数量。可以看到 Albumentation在许多转换方面比所有其他库至少高出 2 倍。

Albumentation Github 的官方 CPU 基准测试https://github.com/albumentations-team/albumentations

下面,我导入了下面的模块:

from PIL import Image
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import cv2
import numpy as np

为了演示的目的,我找了一张前几天毕业回校拍的照片

原始 TorchVision 数据管道

创建一个 Dataloader 来使用 PyTorch 和 Torchvision 处理图像数据管道。

  • 创建一个简单的 Pytorch 数据集类
  • 调用图像并进行转换
  • 用 100 个循环测量整个处理时间

首先,从torch.utils.data获取 Dataset抽象类,并创建一个 TorchVision数据集类。然后我插入图像并使用__getitem__方法进行转换。另外,我用来total_time = (time.time() - start_t测量需要多长时间

class TorchvisionDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]
        
        image = Image.open(file_path)
        
        start_t = time.time()
        if self.transform:
            image = self.transform(image)
        total_time = (time.time() - start_t)

        return image, label, total_time

然后将图像大小调整为 256x256(高度 * 重量)并随机裁剪到 224x224 大小。然后以 50% 的概率应用水平翻转并将其转换为张量。输入文件路径应该是您的图像所在的 Google Drive 的路径。

torchvision_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

torchvision_dataset = TorchvisionDataset(
    file_paths=["demo.jpg"],
    labels=[1],
    transform=torchvision_transform,
)

下面计算从 torchvision_dataset 中提取样本图像并对其进行转换所花费的时间,然后运行 ​​100 次循环以检查它所花费的平均毫秒。

torchvision time/sample: 7.31137752532959 ms

在torch中的GPU,原始 TorchVision 数据管道数据预处理的速度大约是0.0731137752532959 ms。最后输出的图像则为 224x224而且发生了翻转!

Albumentation 数据管道

现在创建了一个 Albumentations Dataset 类,具体的transform和原始 TorchVision 数据管道完全一样。

from PIL import Image
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import cv2
import numpy as np


class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        start_t = time.time()
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
            total_time = (time.time() - start_t)
        return image, label, total_time

albumentations_transform = albumentations.Compose([
    albumentations.Resize(256, 256),
    albumentations.RandomCrop(224, 224),
    albumentations.HorizontalFlip(), # Same with transforms.RandomHorizontalFlip()
    albumentations.pytorch.transforms.ToTensor()
])
albumentations_dataset = AlbumentationsDataset(
    file_paths=["demo.jpg"],
    labels=[1],
    transform=albumentations_transform,
)


total_time = 0
for i in range(100):
  sample, _, transform_time = albumentations_dataset[0]
  total_time += transform_time

print("albumentations time/sample: {} ms".format(total_time*10))

plt.figure(figsize=(10, 10))
plt.imshow(transforms.ToPILImage()(sample))
plt.show()

具体输出如下:

albumentations time/sample: 0.5056881904602051 ms

在torch中的GPU,Albumentation 数据管道 数据管道数据预处理的速度大约是0.005056881904602051 ms。

因此,在真正的工业落地,基本需要将原始 TorchVision 数据管道改写成Albumentation 数据管道,因为落地项目的速度很重要。

Albumentation数据增强

最后,我将展示如何使用albumentations中OneOf函数进行书增强,我个人觉得这个函数在 Albumentation 中非常有用。

from PIL import Image
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import cv2

class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

# OneOf随机采用括号内列出的变换之一。
# 我们甚至可以将发生的概率放在函数本身中。例如,如果 ([…], p=0.5) 之一,它会以 50% 的机会跳过整个变换,并以 1/6 的机会随机选择三个变换之一。
albumentations_transform_oneof = albumentations.Compose([
    albumentations.Resize(256, 256),
    albumentations.RandomCrop(224, 224),
    albumentations.OneOf([albumentations.HorizontalFlip(p=1),albumentations.RandomRotate90(p=1),albumentations.VerticalFlip(p=1)], p=1),
    albumentations.OneOf([albumentations.MotionBlur(p=1),albumentations.OpticalDistortion(p=1), albumentations.GaussNoise(p=1)], p=1),
    albumentations.pytorch.ToTensor()
])


albumentations_dataset = AlbumentationsDataset(
    file_paths=["demo.jpg"],
    labels=[1],
    transform=albumentations_transform_oneof,
)


num_samples = 5
fig, ax = plt.subplots(1, num_samples, figsize=(25, 5))
for i in range(num_samples):
  ax[i].imshow(transforms.ToPILImage()(albumentations_dataset[0][0]))
  ax[i].axis('off')

plt.show()


上面的OneOf是在水平翻转、旋转、垂直翻转中随机选择,在模糊、失真、噪声中随机选择。所以在这种情况下,我们允许 3x3 = 9 种组合

以上是关于深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强的主要内容,如果未能解决你的问题,请参考以下文章

深度学习和目标检测系列教程 1-300:什么是对象检测和常见的8 种基础目标检测算法

深度学习和目标检测系列教程 3-300:了解常见的目标检测的开源数据集

深度学习和目标检测系列教程 2-300:小试牛刀,使用 ImageAI 进行对象检测

深度学习和目标检测系列教程 19-300:关于目标检测APIoU和mAP简介

深度学习和目标检测系列教程 19-300:关于目标检测APIoU和mAP简介

深度学习和目标检测系列教程 5-300:早期的目标检测RCNN架构