使用pytorch保存效果最好那个模型+加载模型
Posted 无脑敲代码,bug漫天飞
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用pytorch保存效果最好那个模型+加载模型相关的知识,希望对你有一定的参考价值。
1 保存在验证集上表现最好的那一轮模型
1 验证集的作用就是监督训练是否过拟合;
一般默认验证集的损失值经历由下降到上升的阶段;
保存在验证集上损失最小的那个迭代模型,其泛化能力应该最好;
# 在训练部分计算验证集损失值,保存最小损失值对应的那个模型
model = BotRGCN()# 自定义模型实例化,()中可以传定义的参数
def train(epoch,min_loss):
model.train()
output = model() # 自动调用定义的forward函数,在()中传相应参数
loss_train = loss(output[et.train_idx],de.labels[et.train_idx])
acc_train = accuracy(output[et.train_idx],de.labels[et.train_idx])
acc_val = accuracy(output[et.val_idx],de.labels[et.val_idx])
# 计算损失值,做比较
loss_val = loss(output[et.val_idx],de.labels[et.val_idx])
optimizer.zero_grad()
loss_train.backgrad()
optimizer.step()
if loss_val < min_loss
min_loss = loss_val
print("save model")
# 保存模型语句
torch.save(model.state_dict(),"model.pth")
return loss_train, acc_train, acc_val, min_loss
if __name__ == "__main__":
epochs = 100
min_loss = 100
for epoch in range(epochs):
loss_train, acc_train, acc_val, min_loss = train(epoch,min_loss)
保存模型中state_dict 是状态字典;
PyTorch 中,一个模型( torch.nn.Module
)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters()
)中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。
模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm
的 running_mean
)。优化器对象(torch.optim
)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数.
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save()
来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
2 加载模型,在测试集上测试模型效果
model = BotRGCN()
model.load_state_dict(torch.load('model.pth'))
model.eval()
test()
在进行预测之前,必须调用 model.eval()
方法来将 dropout
和 batch normalization
层设置为验证模型。否则,只会生成前后不一致的预测结果。
load_state_dict()
方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load()
,而不是直接 model.load_state_dict(PATH)
3 另一种保存与加载方法
加载保存整个模型
保存:
torch.save(model, 'model.pkl')
加载:
# Model class must be defined somewhere
model = torch.load('model.pkl')
model.eval()
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现;
这种实现保存模型的做法将是采用 Python 的 pickle
模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle
并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors
后采用都可能出现错误。
PyTorch保存和加载模型
在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式:
# 方式一:保存模型的结果信息和参数信息 torch.save(model, ‘./model.pth‘) # 方式二:仅保存模型的参数信息 torch.save(model.state_dict(), ‘./model_state.pth‘)
相应的,有两种加载模型的方式:
# 方式一:加载完整的模型结构和参数信息,在网络较大时加载时间比较长,同时存储空间也比较大 model1= torch.load(‘model.pth‘) # 方式二:需先搭建网络模型model2,然后通过下面的语句加载参数 model2.load_state_dic(torch.load(‘model_state.pth‘))
以上是关于使用pytorch保存效果最好那个模型+加载模型的主要内容,如果未能解决你的问题,请参考以下文章