pytorch之 save_reload_model

Posted by dhname

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch之 save_reload_model相关的知识,希望对你有一定的参考价值。

 1 import torch
 2 import matplotlib.pyplot as plt
 3 
 4 # torch.manual_seed(1)    # reproducible
 5 
 6 # fake data
 7 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
 8 y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 9 
10 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
11 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
12 
13 
def save():
    """Train a small regression net on the module-level ``x``/``y``, plot it, and save it.

    Two files are written to the current working directory:

    * ``net.pkl`` -- the entire pickled module (architecture + weights).
    * ``net_params.pkl`` -- only ``net1.state_dict()`` (the parameters).

    Also opens figure 1 (10x3 inches) and draws into its first subplot (131),
    leaving subplots 132 and 133 free.
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The loop's last `prediction` was computed *before* the final optimizer
    # step; re-evaluate so the plot reflects exactly the weights that are saved.
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

    # 2 ways to save the net
    torch.save(net1, 'net.pkl')  # save entire net (pickled module)
    torch.save(net1.state_dict(), 'net_params.pkl')  # save only the parameters
41 
42 
def restore_net():
    """Load the entire pickled net from ``net.pkl`` as net2 and plot its fit.

    Uses the module-level ``x``/``y`` and draws into subplot 132 of the
    current figure.
    """
    # ``net.pkl`` holds a whole pickled nn.Module, not just tensors. Since
    # PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``, which
    # rejects such objects, so opt out explicitly (the argument exists since
    # PyTorch 1.13). Unpickling can execute arbitrary code: only load files
    # you produced yourself.
    net2 = torch.load('net.pkl', weights_only=False)
    with torch.no_grad():  # inference only; no autograd graph needed
        prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
53 
54 
def restore_params():
    """Rebuild the net as net3, load ``net_params.pkl`` into it, plot, and show.

    The layer stack below must match the one whose ``state_dict`` was saved;
    otherwise ``load_state_dict`` (strict by default) raises on missing or
    unexpected keys or mismatched shapes. Uses the module-level ``x``/``y``,
    draws into subplot 133, then calls ``plt.show()``.
    """
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )

    # copy net1's parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    with torch.no_grad():  # inference only; no autograd graph needed
        prediction = net3(x)

    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
    plt.show()
73 
# Train net1 and save it both as a whole module and as a state dict.
save()

# Restore the entire net (unpickling the whole module; may be slower).
restore_net()

# Restore only the net parameters into a freshly built net.
restore_params()

 

以上是关于pytorch之 save_reload_model的主要内容,如果未能解决你的问题,请参考以下文章

pytorch常见问题之cpu占满

pytorch之transforms.Compose()函数理解

Pytorch之线性代数

Pytorch之图像分割(多目标分割,Multi Object Segmentation)

PyTorch之人工智能学习路线

[源码解析] PyTorch 分布式之弹性训练---Rendezvous 引擎