Pytorch 模型的存储与加载
Posted wevolf
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 模型的存储与加载相关的知识,希望对你有一定的参考价值。
Pytorch 模型的存储与加载
本文主要内容来自Pytorch官方文档推荐的一篇英文博客, 本文主要介绍了在Pytorch中模型的存储方法, 以及存储形式, 以及Pytorch存储模型正真存储的是模型的什么结构. 以及加载模型的时候, 模型的哪些数据会被加载. 以及加载后的形式.
首先大致讲下三个最主要的函数的功能:
torch.save: 将序列化的对象存储到硬盘中.此函数使用Python的pickle实用程序进行序列化. 对于数据类型都可以进行序列化存储, 模型, 张量, 以及字典, 等各种数据对象都可以使用该函数存储.
torch.load: 该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也促进设备加载数据.
torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典
模型的加载
state_dict 是什么
在一个Pytorch模型中, 通常是 torch.nn.module , 模型中可学习的参数被包含在模型的参数中, 通常是可以使用 model.parameters()
函数访问, 通常都是使用该方法访问的. state_dict只是一个Python字典对象,它将每个图层映射到其参数张量, 这个字典的 key 是图层的 ‘name‘, 注意, 只有该层有可学习的参数的层, 也就是可以通过反向传播优化的层, 以及 registered buffers (batchnorm’s running_mean) 才会在 state_dict 中有存储条目. 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息. state_dict 的本质是对模型进行了字典化.
state_dict的字典形式使得对模型的操作更加的灵活, 例如直接导出模型, 修改其中的参数信息, 或者对层数进行修改等, 然后继续将模型保留. 还是使用一个简单的模型举个例子:
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
# 这里卷积核的大小是 5, 个数是 6, 输入的 width 是 3
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
# 两次卷积的结果应该是 5x5x16 的矩阵
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
# 可以看出网络层的结构, 两个卷积层, 其余还有全连接层
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model‘s state_dict
print("Model‘s state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, " ", model.state_dict()[param_tensor].size())
# Print optimizer‘s state_dict
print("Optimizer‘s state_dict:")
for var_name in optimizer.state_dict():
print(var_name, " ", optimizer.state_dict()[var_name])
可以得到模型的输出为:
Model‘s state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer‘s state_dict:
state {}
param_groups [{‘lr‘: 0.001, ‘momentum‘: 0.9, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False, ‘params‘: [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
模型的参数的输出是字典的键值对, 后面是优化参数的输出, 也是键值对
存储与加载模型对应的形式
使用 state_dict 存储与加载模型
save:
torch.save(model.state_dict(), PATH)
Load 模型:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
从模型存储的角度, 存储模型的时候, 唯一需要存储的是该模型训练的参数, torch.save() 函数也可以存储模型的 state_dict. 使用该方法进行存储, 模型被看做字典形式, 所以对模型的操作更加灵活. 在这种形式下常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型.
注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先 torch.load() 该模型, 然后再使用 load_state_dict().
将模型作为整体存储与加载
Save:
torch.save(model, PATH)
Load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
使用该方法相当于跳过了对模型的 state_dict 描述的过程, 而是直接使用 python 的 pickle 包, 这种方法的缺点是, 模型的存储形式与加载形式十分固定, 这样做的原因是因为pickle不会保存模型类本身. 而是存出来包含该文件的路径,该路径在加载时使用. 因此,在其他项目中使用或重构后,代码可能会以各种方式中断. 但是这种方法存储的文件的类型与前面的方法一样. 同样, 以该方法加载模型运行之前需要调用 model.eval()
.
存储与加载一般的 Checkpoint
Save:
torch.save({
‘epoch‘: epoch,
‘model_state_dict‘: model.state_dict(),
‘optimizer_state_dict‘: optimizer.state_dict(),
‘loss‘: loss,
...
}, PATH)
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint[‘model_state_dict‘])
optimizer.load_state_dict(checkpoint[‘optimizer_state_dict‘])
epoch = checkpoint[‘epoch‘]
loss = checkpoint[‘loss‘]
model.eval()
# - or -
model.train()
可以看出 checkpoints 是模型主要内容的一个字典, 基本包含了模型各种数据, 例如上面的例子模型的参数使用的是 optimizer.state_dict().
存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便. 为了存储一个训练过程的多种信息, 最好的方式是使用 dictionary 进行序列化, 这样存储一个训练模型的形式是 .tar, 要加载项目,首先初始化模型和优化器,然后使用torch.load() 在本地加载字典.从这里开始, 只需按期望查询字典即可轻松访问已保存的项目. 请记住,在运行推理之前,必须调用model.eval() 来将 Dropout 和 Batch 正则化设置为评估模式, 不这样做将产生不一致的推断结果. 如果恢复训练,那么调用model.train() 以确保这些层处于训练模式.
在一个文件中存储多个模型
save:
torch.save({
‘modelA_state_dict‘: modelA.state_dict(),
‘modelB_state_dict‘: modelB.state_dict(),
‘optimizerA_state_dict‘: optimizerA.state_dict(),
‘optimizerB_state_dict‘: optimizerB.state_dict(),
...
}, PATH)
Load:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint[‘modelA_state_dict‘])
modelB.load_state_dict(checkpoint[‘modelB_state_dict‘])
optimizerA.load_state_dict(checkpoint[‘optimizerA_state_dict‘])
optimizerB.load_state_dict(checkpoint[‘optimizerB_state_dict‘])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
保存包含多个 torch.nn.Modules 的模型(例如GAN,序列到序列模型或模型集合)时,将采用与保存常规检查点相同的方法。 换句话说,保存每个模型的state_dict和相应的优化器的字典. 如前所述,您可以保存任何其他可以帮助您恢复培训的项目,只需将它们添加到字典中即可. 使用该方法存储的文件也是 .tar 形式的, 要加载模型,请首先初始化模型和优化器,然后使用torch.load()在本地加载字典。 从这里,您只需按期望查询字典即可轻松访问已保存的项目.
跨平台模型保存与加载
GPU 到 CPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device(‘cpu‘)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
Save on GPU, Load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
Save on CPU, Load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
以上是关于Pytorch 模型的存储与加载的主要内容,如果未能解决你的问题,请参考以下文章