PyTorch模型 FPS 测试 Benchmark(参考 MMDetection 实现)
Posted Xavier Jiezou
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch模型 FPS 测试 Benchmark(参考 MMDetection 实现)相关的知识,希望对你有一定的参考价值。
引言
深度学习中,模型的速度和性能具有同等重要的地位,因为这直接关系到模型是否能在实际生产应用中落地。在计算机视觉领域,FPS(模型每秒能够处理的图像帧数)是一个重要且直观地反映模型处理速度的指标,基本在所有图像处理类任务中都有用到,例如图像超分,图像修复和目标检测等等。本文从 MMDetection 中抽取了 FPS Benchmark,并做了微小的修改,以便快速测试。
代码
参数 | 描述 |
---|---|
model | 继承 torch.nn.Module 类实例化的 PyTorch 模型。 |
input_size | 模型可接受的输入维度。注意第一个维度是 batch_size ,必须为 1,余下的维度根据模型来设置。 |
device | 选择在 GPU 或 CPU 上测试 FPS。默认是在 CPU 上测试,也支持 GPU,例如 cuda:0 是在机器的第一张独立显卡上测试。 |
warmup_num | 预热次数。因为模型刚开始测试的几轮速度很慢,会影响 FPS 的测试结果,所以我们直接跳过。 |
log_interval | 打印日志的频率,即每隔多少轮打印计算的平均 FPS 值。 |
iterations | 单次测试的总迭代次数。程序会汇总该迭代次数内的所有 FPS 值,并取平均作为我们最终的结果。 |
repeat_num | 重复测试的次数。为进一步缓解测试结果的偶然性,可进行多次重复的测试实验。 |
import torch
import time
class FPSBenchmark():
def __init__(
self,
model: torch.nn.Module,
input_size: tuple,
device: str = "cpu",
warmup_num: int = 5,
log_interval: int = 10,
iterations: int = 100,
repeat_num: int = 1,
) -> None:
"""FPS benchmark.
Ref:
MMDetection: https://mmdetection.readthedocs.io/en/stable/useful_tools.html#fps-benchmark.
Args:
model (torch.nn.Module): model to be tested.
input_size (tuple): model acceptable input size, e.g. `BCHW`, make sure `batch_size` is 1.
device (str): device for test. Default to "cpu".
warmup_num (int, optional): the first several iterations may be very slow so skip them. Defaults to 5.
iterations (int, optional): numer of iterations in a single test. Defaults to 100.
repeat_num (int, optional): number of repeat tests. Defaults to 1.
"""
# Parameters for `load_model`
self.model = model
self.input_size = input_size
self.device = device
# Parameters for `measure_inference_speed`
self.warmup_num = warmup_num
self.log_interval = log_interval
self.iterations = iterations
# Parameters for `repeat_measure_inference_speed`
self.repeat_num = repeat_num
def load_model(self):
model = self.model.to(self.device)
model.eval()
return model
def measure_inference_speed(self):
model = self.load_model()
pure_inf_time = 0
fps = 0
for i in range(self.iterations):
input_data = torch.randn(self.input_size, device=self.device)
if "cuda" in self.device:
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
model(input_data)
torch.cuda.synchronize()
elif "cpu" in self.device:
start_time = time.perf_counter()
with torch.no_grad():
model(input_data)
else:
NotImplementedError(
f"self.device hasn't been implemented yet."
)
elapsed = time.perf_counter() - start_time
if i >= self.warmup_num:
pure_inf_time += elapsed
if (i + 1) % self.log_interval == 0:
fps = (i + 1 - self.warmup_num) / pure_inf_time
print(
f'Done image [i + 1:0>3/self.iterations], '
f'FPS: fps:.2f img/s, '
f'Times per image: 1000 / fps:.2f ms/img',
flush=True,
)
else:
pass
else:
pass
fps = (self.iterations - self.warmup_num) / pure_inf_time
print(
f'Overall FPS: fps:.2f img/s, '
f'Times per image: 1000 / fps:.2f ms/img',
flush=True,
)
return fps
def repeat_measure_inference_speed(self):
assert self.repeat_num >= 1
fps_list = []
for _ in range(self.repeat_num):
fps_list.append(self.measure_inference_speed())
if self.repeat_num > 1:
fps_list_ = [round(fps, 2) for fps in fps_list]
times_pre_image_list_ = [round(1000 / fps, 2) for fps in fps_list]
mean_fps_ = sum(fps_list_) / len(fps_list_)
mean_times_pre_image_ = sum(times_pre_image_list_) / len(
times_pre_image_list_)
print(
f'Overall FPS: fps_list_[mean_fps_:.2f] img/s, '
f'Times per image: '
f'times_pre_image_list_[mean_times_pre_image_:.2f] ms/img',
flush=True,
)
return fps_list
else:
return fps_list[0]
if __name__ == '__main__':
FPSBenchmark(
model=torch.nn.Conv2d(3, 64, 3, 1, 1),
input_size=(1, 3, 224, 224),
device="cuda:0",
).repeat_measure_inference_speed()
参考
https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/benchmark.py
以上是关于PyTorch模型 FPS 测试 Benchmark(参考 MMDetection 实现)的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Pytorch Lightning 微调之前测试模型?
[Pytorch系列-68]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试CycleGAN模型
[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型
[Pytorch系列-67]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型进行测试pix2pix模型