[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)
Posted 文火冰糖的硅基工坊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)相关的知识,希望对你有一定的参考价值。
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121184678
目录
第1章 torchvision与预训练模型的自动下载
第2章 预训练模型的手工下载
第3章 网络介绍
第4章 前置条件:系统库的导入
import torch # torch基础库
import torchvision.models as models # torchvision模型库
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
第5章 预训练模型的导入
5.1 模型的创建
## 创建模型
net = models.resnet101()
print(net)
5.2 模型参数的导入
##导入模型参数
net_params_path = "models/resnet101.pth"
net_params = torch.load(model_params_path)
print(net_params)
5.3 模型参数的应用
# 把加载的参数应用到模型中
net.load_state_dict(net_params)
print(net)
5.4 模型的简单测试
(1)测试1
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print("input shape = ", input.shape)
output = net(input)
print("output shape = ", output.shape)
定义测试数据 input shape = torch.Size([1, 3, 224, 224]) output shape = torch.Size([1, 1000])
(2)测试2:
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print("input shape = ", input.shape)
output = net(input)
print("output shape = ", output.shape)
定义测试数据 input shape = torch.Size([1, 3, 224, 224]) output shape = torch.Size([1, 1000])
此时,可以使用该模型对图片进行预测了!!!
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121184678
以上是关于[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)的主要内容,如果未能解决你的问题,请参考以下文章