best_state 在 pytorch 训练期间随模型而变化
Posted
技术标签:
【中文标题】best_state 在 pytorch 训练期间随模型而变化【英文标题】:best_state changes with the model during training in pytorch 【发布时间】:2019-10-24 20:56:08 【问题描述】:我想保存最佳模型,然后在测试期间加载它。所以我使用了以下方法:
def train():
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state
然后,在我使用的主函数中:
model.load_state_dict(best_state)
恢复模型。
但是,我发现best_state总是和训练时的最后一个状态一样,而不是最好的状态。有人知道原因以及如何避免吗?
顺便说一句,我知道我可以使用torch.save(the_model.state_dict(), PATH)
,然后通过
the_model.load_state_dict(torch.load(PATH))
。
但是,我不想将参数保存到文件中,因为训练和测试函数在一个文件中。
【问题讨论】:
版本 1.1.0, linux, GPU 【参考方案1】:model.state_dict()
是OrderedDict
from collections import OrderedDict
你可以使用:
from copy import deepcopy
解决问题
改为:
best_state = model.state_dict()
你应该使用:
best_state = copy.deepcopy(model.state_dict())
深(非浅)副本使可变的 OrderedDict 实例不会随其发生变化 best_state
。
您可以查看我的other answer 在 PyTorch 中保存状态字典。
【讨论】:
【参考方案2】:当你保存模型的状态时,你应该在网络中保存以下内容
1) 优化器状态和 2) 模型的状态字典
您可以在类模型中定义一种方法,如下所示
def save_state(state,filename):
torch.save(state,filename)
''' 当您保存状态时,请执行以下操作: '''
Model model //for example
model.save_state('state_dict':model.state_dict(), 'optimizer': optimizer.state_dict())
保存的模型将存储为model.pth.tar(例如)
现在在加载过程中执行以下步骤,
checkpoint = torch.load('model.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
希望这会对你有所帮助。
【讨论】:
这个答案很适合this question,但不完全适合这个。 谢谢,我知道我可以将状态保存到文件中。但我更喜欢 Saurav 的回答 @Lei_Bai 很高兴能帮到你一个小忙。希望我将来也能为您提供帮助。以上是关于best_state 在 pytorch 训练期间随模型而变化的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow 在训练期间没有使用我的 M1 MacBook GPU