best_state 在 pytorch 训练期间随模型而变化

Posted

技术标签:

【中文标题】best_state 在 pytorch 训练期间随模型而变化【英文标题】:best_state changes with the model during training in pytorch 【发布时间】:2019-10-24 20:56:08 【问题描述】:

我想保存最佳模型,然后在测试期间加载它。所以我使用了以下方法:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  

然后,在我使用的主函数中:

model.load_state_dict(best_state)  

恢复模型。

但是,我发现best_state总是和训练时的最后一个状态一样,而不是最好的状态。有人知道原因以及如何避免吗?

顺便说一句,我知道我可以使用torch.save(the_model.state_dict(), PATH),然后通过 the_model.load_state_dict(torch.load(PATH))。 但是,我不想将参数保存到文件中,因为训练和测试函数在一个文件中。

【问题讨论】:

版本 1.1.0, linux, GPU 【参考方案1】:

model.state_dict()OrderedDict

from collections import OrderedDict

你可以使用:

from copy import deepcopy

解决问题

改为:

best_state = model.state_dict() 

你应该使用:

best_state = copy.deepcopy(model.state_dict())

深(非浅)副本使可变的 OrderedDict 实例不会随其发生变化 best_state

您可以查看我的other answer 在 PyTorch 中保存状态字典。

【讨论】:

【参考方案2】:

当你保存模型的状态时,你应该在网络中保存以下内容

1) 优化器状态和 2) 模型的状态字典

您可以在类模型中定义一种方法,如下所示

def save_state(state,filename):
    torch.save(state,filename)

''' 当您保存状态时,请执行以下操作: '''

Model model //for example  
model.save_state('state_dict':model.state_dict(), 'optimizer': optimizer.state_dict()) 

保存的模型将存储为model.pth.tar(例如)

现在在加载过程中执行以下步骤,

checkpoint = torch.load('model.pth.tar')         

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

希望这会对你有所帮助。

【讨论】:

这个答案很适合this question,但不完全适合这个。 谢谢,我知道我可以将状态保存到文件中。但我更喜欢 Saurav 的回答 @Lei_Bai 很高兴能帮到你一个小忙。希望我将来也能为您提供帮助。

以上是关于best_state 在 pytorch 训练期间随模型而变化的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 在训练期间没有使用我的 M1 MacBook GPU

PyTorch 中的高效指标评估

Pytorch CNN 不学习

在PyTorch中构建高效的自定义数据集

pytorch - 如何从 DistributedDataParallel 学习中保存和加载模型

pytorch Dropout:“通道将独立清零”