PyTorch参数模型转换为PT模型

Posted SpikeKing

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch参数模型转换为PT模型相关的知识,希望对你有一定的参考价值。

当PyTorch模型需要部署到服务时,为了提升访问速度,需要转换为TRT模型,再进行部署。在转换为TRT模型之前,需要将PyTorch参数模型(如pth.tar)转换为pt模型,使用jit形式。pt模型 = 参数模型(pth.tar) + 网络结构(如resnet50)。使用pt模型,可以简化使用方式,同时也方便转换为trt模型,进行轻量级部署。在转换函数中,包含验证逻辑,保证转换前后的模型效果一致,即输出不变。

以图像分类框架pytorch-image-models-my为例,将PyTorch的pth.tar模型转换为PT模型。

转换流程如下:

  1. 加载pth.tar模型model,model达到可以预测的标准,即:
# 加载模型
model = timm.create_model(model_name=base_net, pretrained=False,
                          checkpoint_path=model_path, num_classes=num_classes)
if torch.cuda.is_available():
    print('[Info] cuda on!!!')
    model = model.cuda()
model.eval()

# 预测结果
print('[Info] 预测图像尺寸: {}'.format(img_rgb.shape))
img_tensor = self.preprocess_img(img_rgb, self.transform)
print('[Info] 模型输入: {}'.format(img_tensor.shape))
with torch.no_grad():
    out = self.model(img_tensor)
  1. 将已加载的模型model,通过torch.jit.trace()模拟输入dummy_input,调用traced.save()存储成pt模型,即:

    • 注意输入尺寸dummy_shape,用于生成模拟的input数据,需要与模型输入保持一致

    • 注意是否支持GPU,即orch.cuda.is_available(),判断环境是cuda还是cpu。

dummy_shape = (1, 3, 336, 336)  # 不影响模型
print('[Info] dummy_shape: {}'.format(dummy_shape))
if torch.cuda.is_available():
    model_type = "cuda"
else:
    model_type = "cpu"
print('[Info] model_type: {}'.format(model_type))
dummy_input = torch.empty(dummy_shape,
                          dtype=torch.float32,
                          device=torch.device(model_type))
traced = torch.jit.trace(self.model, dummy_input)
pt_path = os.path.join(pt_folder_path, "{}_{}.pt".format(model_name, model_type))
traced.save(pt_path)
  1. 验证pt模型是否与原模型pth.tar的输出是否一致,pt模型调用reload_script(),即:
with torch.no_grad():
    standard_out = self.model(dummy_input)
print('[Info] standard_out: {}'.format(standard_out))

reload_script = torch.jit.load(pt_path)
with torch.no_grad():
    script_output = reload_script(dummy_input)
print('[Info] script_output: {}'.format(script_output))
print('[Info] 验证 is equal: {}'.format(F.l1_loss(standard_out, script_output)))

print('[Info] 存储完成: {}'.format(pt_path))

全部转换和验证PT模型的逻辑,都位于save_pt()函数中,调用即可生成,输出位于pt_models文件夹中,即:

me.save_pt(os.path.join(DATA_DIR, "pt_models"))

输出的模型是:model_best_c2_20210915_cpu.pt,GPU版本是:model_best_c2_20210915_cuda.pt

在pytorch-image-models-my工程中,pth.tar模型转换为PT模型的转换脚本,源码如下,参考model_2_pt_script.py

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 15.9.21
"""
import argparse
import os
import sys

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)

from root_dir import DATA_DIR
from myscripts.img_predictor import ImgPredictor


def parse_args():
    """
    处理脚本参数
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转换PT模型')
    parser.add_argument('-m', dest='model_path', required=True, help='模型路径', type=str)
    parser.add_argument('-n', dest='base_net', required=False, help='basenet', type=str, default="resnet50")
    parser.add_argument('-c', dest='num_classes', required=False, help='类别个数', type=int, default=2)
    parser.add_argument('-o', dest='out_dir', required=False, help='输出文件夹', type=str,
                        default=os.path.join(DATA_DIR, "pt_models"))

    args = parser.parse_args()

    arg_model_path = args.model_path
    print("[Info] 模型路径: {}".format(arg_model_path))

    arg_base_net = args.base_net
    print("[Info] basenet: {}".format(arg_base_net))

    arg_num_classes = args.num_classes
    print("[Info] 类别数: {}".format(arg_num_classes))

    arg_out_dir = args.out_dir
    print("[Info] 输出文件夹: {}".format(arg_out_dir))

    return arg_model_path, arg_base_net, arg_num_classes, arg_out_dir


def main():
    """
    入口函数
    """
    print('[Info] ' + "-" * 100)
    print('[Info] 转换PT模型开始')
    arg_model_path, arg_base_net, arg_num_classes, arg_out_dir = parse_args()
    me = ImgPredictor(arg_model_path, arg_base_net, arg_num_classes)
    pt_path = me.save_pt(arg_out_dir)  # 存储PT模型
    print('[Info] 存储完成: {}'.format(pt_path))
    print('[Info] ' + "-" * 100)


if __name__ == "__main__":
    main()

以上是关于PyTorch参数模型转换为PT模型的主要内容,如果未能解决你的问题,请参考以下文章

YOLOV3——PyTorch训练TensorFlowLite部署模型转换

[Pytorch].pth转.pt文件

Pytorch模型转Android端模型

Pytorch模型转Android端模型

pytorch自动删除之前保存的pt文件

pytorch模型转换为rknn模型,使用npu推理