模型的加载和保存
Posted 1994july
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了模型的加载和保存相关的知识,希望对你有一定的参考价值。
pytorch三种模型的加载保存操作
方法1 : PATH表示保存模型的路径和文件名
torch.save(model, PATH) model = torch.load(PATH) model.eval()
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = Model(n_input_features=6)
# train your medel...
# save model
FILE = "model.pth"
torch.save(model, FILE)
# load model
model = torch.load(FILE)
# 防止模型参数发生变化
model.eval()
for param in model.parameters():
print(param)
方法二:
保存模型时使用模型的state_dict()方法,加载模型前先实例化一个模型,然后调用load_state_dict()方法
torch.save(model.state_dict(), PATH) # model must be created again with parameters model = Model(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = Model(n_input_features=6)
# train your medel...
for param in model.parameters():
print(param)
# save model
FILE = "model.pth"
torch.save(model.state_dict(), FILE)
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))
# 防止模型参数发生变化
loaded_model.eval()
for param in loaded_model.parameters():
print(param)
方法三:
定义一个字典,保存多个参数到模型
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = Model(n_input_features=6)
# train your medel...
# print(model.state_dict())
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# print(optimizer.state_dict())
checkpoint = {
"epoch": 90,
"model_state": model.state_dict(),
"optim_state": optimizer.state_dict()
}
# 保存三种数据到模型
torch.save(checkpoint, "checkpoint.pth")
# 加载模型
loaded_checkpoint = torch.load("checkpoint.pth")
# 载入epcho数据
epoch = loaded_checkpoint[‘epoch‘]
print(epoch)
# 定义模型和优化器
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)
# 将保存的模型数据载入到模型和优化器中
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optim_state"])
以上是关于模型的加载和保存的主要内容,如果未能解决你的问题,请参考以下文章