PyTorch学习网络的保存与提取
Posted My heart will go ~~
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习网络的保存与提取相关的知识,希望对你有一定的参考价值。
保存提取神经网络
神经网络一般保存为.pkl文件,代码非常简单。
保存:torch.save(net1,‘net.pkl’) #将net1保存为net.pkl
提取:net2=torch.load(‘net.pkl’)#提取网络到net2
第二种方法
保存 torch.save(net1.state_dict(),‘net_params.pkl’)#保存参数
提取,首先要先定义一个与之前的一样的网络
net3=torch.nn.Sequential(#先建立与net1一样的神经网络
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
net3.load_state_dict(torch.load(‘net_params.pkl’))
import torch
import matplotlib.pyplot as plt
#torch.manual_seed(1)
# 数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
def save():
#快速搭建的神经网络
net1=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
#进行训练所有参数
optimizer=torch.optim.SGD(net1.parameters(), lr=0.5)#训练所有参数
loss_func=torch.nn.MSELoss()
for t in range(100):
prediction=net1(x)
loss=loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
#训练好了以后,保存神经网络
torch.save(net1,'net.pkl')#保存整个网络
torch.save(net1.state_dict(),'net_params.pkl')#保存参数
#提取方法1
def restore_net():
net2=torch.load('net.pkl')#提取网络到net2
prediction=net2(x)
#plot
plt.subplot(132)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
#提取方法2
def restore_params():
net3=torch.nn.Sequential(#先建立与net1一样的神经网络
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction=net3(x)
plt.subplot(133)
plt.title('net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
plt.show()
save()#save的是net1
restore_net()
restore_params()
结果为三张图都一样
以上是关于PyTorch学习网络的保存与提取的主要内容,如果未能解决你的问题,请参考以下文章