每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save
Posted cv.exp
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save相关的知识,希望对你有一定的参考价值。
今天我们讲解:
state_dict = torch.load('checkpoint.pt')
#或者
state_dict = torch.load('checkpoint.pth') #torch.load加载**模型参数**
model.load_state_dict(state_dict) #把模型参数加载到模型中
model.cuda()
model.eval() #model.eval()关闭Batch Normalization和Dropout层
#加载模型结构和模型参数
model = torch.load(path)
output = model(x)
torch.save(model.state_dict(), ‘checkpoint.pt’) #仅保存模型参数
torch.save(model,‘checkpoint.pt’) #保存模型结构和模型参数
以上是关于每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save的主要内容,如果未能解决你的问题,请参考以下文章