PyTorch:模型save和load
Posted -柚子皮-
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch:模型save和load相关的知识,希望对你有一定的参考价值。
神经网络训练后我们需要将模型进行保存,要用的时候将保存的模型进行加载。
PyTorch 中保存模型主要分为两类:保存整个模型和只保存模型参数。
A common PyTorch convention is to save models using either a .pt
or .pth
file extension.
保存加载整个模型(不推荐)
保存整个网络模型
(网络结构+权重参数)
torch.save(model, 'net.pth')
出错:
AttributeError: Can't pickle local object 'AtomicModel.get_metrics.<locals>.<lambda>'
AttributeError: Can't pickle local object 'AtomicModel._get_metrics.<locals>._accuracy_score'
原因:pickle不能序列化lambda函数,或者是闭包。[python模块 - pickle模块]
加载整个网络模型
(可能比较耗时)
model = torch.load('net.pth')
只保存加载模型参数(推荐)
保存模型的权重参数
(速度快,占内存少)
torch.save(model.state_dict(), 'net_params.pth')
load模型参数
因为我们只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。
# 构建一个网络结构
model = ClassNet()
# 将模型参数加载到新模型中,torch.load 返回的是一个 OrderedDict,说明.state_dict()
只是把所有模型的参数都以OrderedDict
的形式存下来。
state_dict = torch.load('net_params.pth')
model.load_state_dict(state_dict)
Note: 保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。
load_state_dict的参数Strict=False
new_model.load_state_dict(state_dict, strict=False)
如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key。但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。
[Pytorch学习(十七)--- 模型load各种问题解决]
保存加载自定义模型
上面“保存加载整个模型”加载的 net.pt 其实一个字典,通常包含如下内容:
网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict() 方法来获取,比如只保存模型权重参数时用到的 model.state_dict()。
优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
其他信息:有时我们需要保存一些其他的信息,比如 epoch,batch_size 等超参数。
我们可以自定义需要save的内容
# saving a checkpoint assuming the network class named ClassNet
checkpoint = 'model': ClassNet(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
torch.save(checkpoint, 'checkpoint.pkl')
上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。
然后我们要load上面保存的自定义的模型
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
model = checkpoint['model'] # 提取网络结构
model.load_state_dict(checkpoint['model_state_dict']) # 加载网络权重参数
optimizer = TheOptimizerClass()
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数
for parameter in model.parameters():
parameter.requires_grad = False
model.eval()
return model
model = load_checkpoint('checkpoint.pkl')
后续使用
如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropout 和 batch normalization 层进行固定,否则模型的预测结果每次都会不同。
如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。
跨设备保存加载模型
在 CPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on CPU):
device = torch.device('cpu')
model = TheModelClass()
# Load all tensors onto the CPU device
model.load_state_dict(torch.load('net_params.pkl', map_location=device))
map_location:a function, torch.device, string or a dict specifying how to remap storage locations
令 torch.load() 函数的 map_location 参数等于 torch.device('cpu') 即可。 这里令 map_location 参数等于 'cpu' 也同样可以。
from: -柚子皮-
ref: [SAVING AND LOADING MODELS]
以上是关于PyTorch:模型save和load的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch中通过torch.save保存模型和torch.load加载模型介绍
Pytorch model saving and loading 模型保存和读取
每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save
每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save