PyTorch学习网络的保存与提取

Posted My heart will go ~~

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习网络的保存与提取相关的知识,希望对你有一定的参考价值。

保存提取神经网络
神经网络一般保存为.pkl文件,代码非常简单。

保存:torch.save(net1,‘net.pkl’) #将net1保存为net.pkl
提取:net2=torch.load(‘net.pkl’)#提取网络到net2

第二种方法
保存 torch.save(net1.state_dict(),‘net_params.pkl’)#保存参数
提取,首先要先定义一个与之前的一样的网络
net3=torch.nn.Sequential(#先建立与net1一样的神经网络
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
net3.load_state_dict(torch.load(‘net_params.pkl’))

import torch
import matplotlib.pyplot as plt

#torch.manual_seed(1)

# 数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())


def save():
    #快速搭建的神经网络
    net1=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    
    #进行训练所有参数
    optimizer=torch.optim.SGD(net1.parameters(), lr=0.5)#训练所有参数
    loss_func=torch.nn.MSELoss()
    
    for t in range(100):
        prediction=net1(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    
 
    #训练好了以后,保存神经网络
    torch.save(net1,'net.pkl')#保存整个网络
    torch.save(net1.state_dict(),'net_params.pkl')#保存参数
  
        
#提取方法1
def restore_net():
    net2=torch.load('net.pkl')#提取网络到net2
    prediction=net2(x)
    #plot
    plt.subplot(132)
    plt.title('net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

#提取方法2
def restore_params():
    net3=torch.nn.Sequential(#先建立与net1一样的神经网络
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction=net3(x)
    
    plt.subplot(133)
    plt.title('net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5) 
    plt.show()
    
    
    
save()#save的是net1
restore_net()
restore_params()

结果为三张图都一样

在这里插入图片描述

以上是关于PyTorch学习网络的保存与提取的主要内容,如果未能解决你的问题,请参考以下文章

pytorch学习-4:快速搭建+保存提取

pytorch学习-4:快速搭建+保存提取

从单个按钮从多个片段中提取数据

[Pytorch]Pytorch 保存模型与加载模型(转)

Pytorch 网络模型的保存与读取

基于VGG19神经网络的提取特征 进行 可见光与红外光的 图像融合 基于pytorch 实现。。。