PyTorch Model FPS Benchmark (Based on the MMDetection Implementation)

Posted by Xavier Jiezou


Introduction

In deep learning, a model's speed matters just as much as its accuracy, because it directly determines whether the model can be deployed in real production settings. In computer vision, FPS (the number of image frames a model can process per second) is an important and intuitive measure of processing speed, and it appears in essentially every image-processing task, such as image super-resolution, image restoration, and object detection. This post extracts the FPS benchmark from MMDetection, with minor modifications to make quick testing easier.
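As a quick illustration of the arithmetic (a minimal sketch with made-up numbers, not part of the benchmark class below): FPS is simply the number of timed frames divided by the accumulated pure inference time, and the per-image latency in milliseconds is its reciprocal times 1000.

# Sketch of the FPS arithmetic used by the benchmark below.
# Hypothetical numbers: 95 timed frames in 1.9 s of pure inference time.
timed_frames = 95
pure_inf_time = 1.9  # seconds, excluding warm-up iterations
fps = timed_frames / pure_inf_time  # 50.0 img/s
ms_per_image = 1000 / fps           # 20.0 ms/img
print(f"FPS: {fps:.2f} img/s, Times per image: {ms_per_image:.2f} ms/img")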

Code

Parameter descriptions

model: a PyTorch model instance, i.e. an instance of a torch.nn.Module subclass.
input_size: the input shape the model accepts. Note that the first dimension is the batch_size and must be 1; set the remaining dimensions according to the model.
device: whether to measure FPS on the CPU or on a GPU. Defaults to the CPU; GPUs are also supported, e.g. cuda:0 runs the test on the machine's first discrete GPU (a note on GPU timing follows this list).
warmup_num: number of warm-up iterations. The first few iterations after start-up are very slow and would skew the FPS result, so they are skipped.
log_interval: logging frequency, i.e. how many iterations pass between prints of the running average FPS.
iterations: total number of iterations in a single test. The FPS values accumulated over these iterations are averaged to produce the final result.
repeat_num: number of repeated tests. To further reduce run-to-run variance, the whole measurement can be repeated several times.
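One timing detail worth noting before the code: CUDA kernels are launched asynchronously, so stopping a wall-clock timer right after a forward pass on the GPU mostly measures kernel launch overhead rather than the actual inference. The sketch below (assuming a CUDA-capable machine; the layer and input shape are placeholders) shows the difference, which is why the benchmark class synchronizes before starting the timer and again after the forward pass.

import time
import torch

# Minimal sketch (assumes a CUDA device is available): compare a naive timing
# with one that synchronizes the GPU around the timed region.
model = torch.nn.Conv2d(3, 64, 3, 1, 1).to("cuda:0").eval()
x = torch.randn(1, 3, 224, 224, device="cuda:0")

start = time.perf_counter()
with torch.no_grad():
    model(x)
naive = time.perf_counter() - start  # the GPU may still be running here

torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    model(x)
torch.cuda.synchronize()  # wait until the GPU has actually finished
synced = time.perf_counter() - start
print(f"naive: {naive * 1000:.3f} ms, synchronized: {synced * 1000:.3f} ms")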
import torch
import time


class FPSBenchmark():
    def __init__(
        self,
        model: torch.nn.Module,
        input_size: tuple,
        device: str = "cpu",
        warmup_num: int = 5,
        log_interval: int = 10,
        iterations: int = 100,
        repeat_num: int = 1,
    ) -> None:
        """FPS benchmark.

        Ref:
            MMDetection: https://mmdetection.readthedocs.io/en/stable/useful_tools.html#fps-benchmark.

        Args:
            model (torch.nn.Module): model to be tested.
            input_size (tuple): model acceptable input size, e.g. `BCHW`, make sure `batch_size` is 1.
            device (str): device on which to run the test. Defaults to "cpu".
            warmup_num (int, optional): the first several iterations may be very slow, so skip them. Defaults to 5.
            log_interval (int, optional): interval (in iterations) between FPS log prints. Defaults to 10.
            iterations (int, optional): number of iterations in a single test. Defaults to 100.
            repeat_num (int, optional): number of repeat tests. Defaults to 1.
        """
        # Parameters for `load_model`
        self.model = model
        self.input_size = input_size
        self.device = device

        # Parameters for `measure_inference_speed`
        self.warmup_num = warmup_num
        self.log_interval = log_interval
        self.iterations = iterations

        # Parameters for `repeat_measure_inference_speed`
        self.repeat_num = repeat_num

    def load_model(self):
        model = self.model.to(self.device)
        model.eval()
        return model

    def measure_inference_speed(self):
        model = self.load_model()
        pure_inf_time = 0
        fps = 0

        for i in range(self.iterations):
            input_data = torch.randn(self.input_size, device=self.device)
            if "cuda" in self.device:
                torch.cuda.synchronize()
                start_time = time.perf_counter()
                with torch.no_grad():
                    model(input_data)
                torch.cuda.synchronize()
            elif "cpu" in self.device:
                start_time = time.perf_counter()
                with torch.no_grad():
                    model(input_data)
            else:
                NotImplementedError(
                    f"self.device hasn't been implemented yet."
                )
            elapsed = time.perf_counter() - start_time

            if i >= self.warmup_num:
                pure_inf_time += elapsed
                if (i + 1) % self.log_interval == 0:
                    fps = (i + 1 - self.warmup_num) / pure_inf_time
                    print(
                        f'Done image [{i + 1:0>3}/{self.iterations}], '
                        f'FPS: {fps:.2f} img/s, '
                        f'Times per image: {1000 / fps:.2f} ms/img',
                        flush=True,
                    )
        fps = (self.iterations - self.warmup_num) / pure_inf_time
        print(
            f'Overall FPS: {fps:.2f} img/s, '
            f'Times per image: {1000 / fps:.2f} ms/img',
            flush=True,
        )
        return fps

    def repeat_measure_inference_speed(self):
        assert self.repeat_num >= 1
        fps_list = []
        for _ in range(self.repeat_num):
            fps_list.append(self.measure_inference_speed())
        if self.repeat_num > 1:
            fps_list_ = [round(fps, 2) for fps in fps_list]
            times_pre_image_list_ = [round(1000 / fps, 2) for fps in fps_list]
            mean_fps_ = sum(fps_list_) / len(fps_list_)
            mean_times_pre_image_ = sum(times_pre_image_list_) / len(
                times_pre_image_list_)
            print(
                f'Overall FPS: {fps_list_}[{mean_fps_:.2f}] img/s, '
                f'Times per image: '
                f'{times_pre_image_list_}[{mean_times_pre_image_:.2f}] ms/img',
                flush=True,
            )
            return fps_list
        else:
            return fps_list[0]


if __name__ == '__main__':
    FPSBenchmark(
        model=torch.nn.Conv2d(3, 64, 3, 1, 1),
        input_size=(1, 3, 224, 224),
        device="cuda:0",
    ).repeat_measure_inference_speed()
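
For something a bit heavier than the single convolution above, the following hypothetical usage sketch (torchvision, the ResNet-18 choice, and the parameter values are my own, not from the original post) benchmarks a classification backbone on the CPU with extra warm-up and three repeated runs, using the FPSBenchmark class defined above:

# Hypothetical usage sketch: a torchvision ResNet-18 benchmarked on the CPU.
import torchvision

FPSBenchmark(
    model=torchvision.models.resnet18(),
    input_size=(1, 3, 224, 224),
    device="cpu",
    warmup_num=10,
    iterations=200,
    repeat_num=3,
).repeat_measure_inference_speed()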

References

https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/benchmark.py
