模型的加载和保存

Posted 1994july

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了模型的加载和保存相关的知识,希望对你有一定的参考价值。

pytorch三种模型的加载保存操作

方法1 : PATH表示保存模型的路径和文件名

torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...

# save model
FILE = "model.pth"
torch.save(model, FILE)

# load model
model = torch.load(FILE)

# 防止模型参数发生变化
model.eval()
for param in model.parameters():
    print(param)

方法二:

保存模型时使用模型的state_dict()方法,加载模型前先实例化一个模型,然后调用load_state_dict()方法

torch.save(model.state_dict(), PATH)
# model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...


for param in model.parameters():
    print(param)

# save model
FILE = "model.pth"
torch.save(model.state_dict(), FILE)

loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))

# 防止模型参数发生变化
loaded_model.eval()
for param in loaded_model.parameters():
    print(param)

方法三:

定义一个字典,保存多个参数到模型

class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...

# print(model.state_dict())


learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# print(optimizer.state_dict())


checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict()
}
# 保存三种数据到模型
torch.save(checkpoint, "checkpoint.pth")

# 加载模型
loaded_checkpoint = torch.load("checkpoint.pth")
# 载入epcho数据
epoch = loaded_checkpoint[‘epoch‘]
print(epoch)

# 定义模型和优化器
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)


# 将保存的模型数据载入到模型和优化器中
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optim_state"])

 推荐:什么是顺时针,你看到的是顺时针还是逆时针

以上是关于模型的加载和保存的主要内容,如果未能解决你的问题,请参考以下文章

如何保存和加载 Android 活动的预设?

保存和加载模型

在 R 中保存和加载模型

敲除加载和保存视图模型

保存片段状态操作栏选项卡

使用 BottomBar 和片段容器禁用 Android 片段重新加载