神经网络和深度学习-用pytorch实现线性回归

Posted Ricardo_PING_

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了神经网络和深度学习-用pytorch实现线性回归相关的知识,希望对你有一定的参考价值。

用pytorch实现线性回归

用pytorch的工具包来实现线性模型的训练过程

  • 准备数据集

  • 设计模型

  • 构造损失函数和优化器(使用pytorch API)

  • 训练过程:前馈、反馈、更新

准备数据

在PyTorch中,计算图是小批处理的,所以X和Y是3x1 Tensors。(X和Y的值必须是个矩阵)

利用广播将矩阵进行衍生

同时损失函数也是这样进行计算的

设计模型

采用仿射模型Affine Model(线性单元Linear Unit) 来构造计算图

在线性单元中我们想要知道权重的维度大小,首先需要知道输入x的维度,以及输出y的维度,之后放入loss函数中计算损失值,最后进行backward来对整个计算图进行反向传播

我们要得到loss的标量,可以进行求和或mean,如果最后求出来的是向量的话,是不能进行backward操作

接下来我们设计模型,我们的模型类应该继承于nn.Module,它是所有神经网络模块的Baseclass

在这个类中我们需要实现两个函数,一个是构造函数init,另一个是前馈函数forward,用module实现的类会自动实现backward的过程

构造函数init中,我们必须调用父类的init,下一行中实际上是在构造一个对象,包含了权重w和偏置b两个Tensors,其中Linear也是继承自Module

我们来看一下Linear这个类,详细介绍一下

  • in_features: 输入样本的维度,在mini-batch中的矩阵中,行表示样本,列表示features

  • out_features: 输出样本的维度

  • bias:是一个布尔类型的参数,用来决定是否有偏置

我们在实际的运算中,是进行转置计算的

forward函数,self.linear(x)实现了一个可调用的对象

创造一个实例类,model是可以被直接调用的,model(x)

构造损失函数和优化器

criterion: 其中MSELoss也是继承于Module父类中的,其中有两个参数

  • size_average:样本是否要求均值

  • reduce:是否要进行降维

接下来是优化器optimizer。其中使用的是SGD里面有这么几个参数

  • params:检查model里面的所有成员,成员里面有相应的权重,把这些都加到训练的结果上,调用所有权重值:linear.parameters()

  • lr:learning rate 学习率,可以设置不同的学习率

训练过程

按照model计算y_pred 计算loss函数 梯度归0 backward 最后进行更新

最后我们进行输出权重和偏置,在设置测试模型

在训练中我们发现迭代100次训练不能达到要求时,我们可以进行1000次训练

完整代码演示

import torch
import matplotlib.pyplot as plt

# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


# design model using class
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的
        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = LinearModel()

# construct loss and optimizer
criterion =torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

epoch_list = []
loss_list = []
# training cycle forward, backward, update
for epoch in range(1000):
    y_pred = model(x_data)  # forward:predict
    loss = criterion(y_pred,y_data)  # forward: loss
    print(epoch,loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())

    optimizer.zero_grad()
    loss.backward()  # backward: autograd,自动计算梯度
    optimizer.step()  # update 参数,即更新w和b的值

print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

x_test = torch.Tensor([4.0])
y_test = model(x_test)
print('y_pred = ',y_test.data)

plt.plot(epoch_list, loss_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

训练结果

0 112.35303497314453
1 50.16391372680664
2 22.476961135864258
3 10.149425506591797
4 4.659496784210205
5 2.213507890701294
6 1.1226205825805664
7 0.6350162029266357
8 0.41600528359413147
9 0.316591739654541
10 0.27044761180877686
11 0.24804410338401794
12 0.2362363487482071
13 0.22917166352272034
14 0.22424422204494476
15 0.2202942818403244
16 0.216804638504982
17 0.21354466676712036
18 0.2104116678237915
19 0.20735913515090942
20 0.2043665200471878
21 0.20142365992069244
22 0.19852635264396667
23 0.1956721395254135
24 0.19285938143730164
25 0.1900874376296997
26 0.18735529482364655
27 0.1846628487110138
28 0.18200890719890594
29 0.1793932020664215
30 0.1768149435520172
31 0.17427413165569305
32 0.17176926136016846
33 0.16930058598518372
34 0.1668677031993866
35 0.1644694209098816
36 0.16210560500621796
37 0.15977603197097778
38 0.15747982263565063
39 0.1552167683839798
40 0.15298616886138916
41 0.15078730881214142
42 0.1486203372478485
43 0.1464843451976776
44 0.14437904953956604
45 0.1423041820526123
46 0.14025893807411194
47 0.13824328780174255
48 0.1362564116716385
49 0.13429833948612213
50 0.132368266582489
51 0.13046585023403168
52 0.12859109044075012
53 0.1267428994178772
54 0.12492141127586365
55 0.1231260746717453
56 0.12135656177997589
57 0.11961238831281662
58 0.11789334565401077
59 0.11619921773672104
60 0.11452922224998474
61 0.11288312822580338
62 0.11126099526882172
63 0.10966184735298157
64 0.10808579623699188
65 0.106532521545887
66 0.10500149428844452
67 0.10349243879318237
68 0.10200508683919907
69 0.10053901374340057
70 0.09909424185752869
71 0.09766990691423416
72 0.09626630693674088
73 0.09488292038440704
74 0.09351934492588043
75 0.09217515587806702
76 0.09085039794445038
77 0.08954479545354843
78 0.08825798332691193
79 0.08698944002389908
80 0.08573935925960541
81 0.0845070481300354
82 0.08329254388809204
83 0.08209548890590668
84 0.08091577142477036
85 0.07975286990404129
86 0.07860668748617172
87 0.07747688889503479
88 0.07636338472366333
89 0.07526612281799316
90 0.07418430596590042
91 0.07311809062957764
92 0.07206742465496063
93 0.07103154063224792
94 0.07001075148582458
95 0.06900466978549957
96 0.06801298260688782
97 0.06703557074069977
98 0.066072016954422
99 0.06512241810560226
100 0.06418654322624207
101 0.06326410919427872
102 0.062355078756809235
103 0.06145886704325676
104 0.060575492680072784
105 0.05970491096377373
106 0.058846890926361084
107 0.05800115689635277
108 0.05716757848858833
109 0.056346092373132706
110 0.0555361732840538
111 0.054738033562898636
112 0.053951431065797806
113 0.0531761460006237
114 0.05241178348660469
115 0.05165863409638405
116 0.05091621354222298
117 0.05018452927470207
118 0.04946324974298477
119 0.048752374947071075
120 0.04805169627070427
121 0.047361113131046295
122 0.04668048769235611
123 0.04600962996482849
124 0.045348361134529114
125 0.0446966178715229
126 0.044054266065359116
127 0.04342120513319969
128 0.04279709234833717
129 0.042182084172964096
130 0.0415758416056633
131 0.040978316217660904
132 0.040389373898506165
133 0.03980895131826401
134 0.03923676908016205
135 0.038672927767038345
136 0.03811722993850708
137 0.037569429725408554
138 0.03702935576438904
139 0.03649728000164032
140 0.03597274795174599
141 0.035455767065286636
142 0.0349462553858757
143 0.034443940967321396
144 0.03394889459013939
145 0.03346104547381401
146 0.03298007696866989
147 0.032506152987480164
148 0.032039009034633636
149 0.03157859295606613
150 0.031124703586101532
151 0.030677346512675285
152 0.030236493796110153
153 0.029801972210407257
154 0.029373735189437866
155 0.028951536864042282
156 0.028535399585962296
157 0.028125345706939697
158 0.027721144258975983
159 0.027322763577103615
160 0.026930106803774834
161 0.026543039828538895
162 0.02616160549223423
163 0.025785554200410843
164 0.02541501820087433
165 0.02504979632794857
166 0.024689771234989166
167 0.02433495968580246
168 0.02398517169058323
169 0.023640474304556847
170 0.02330072969198227
171 0.022965876385569572
172 0.022635798901319504
173 0.02231057733297348
174 0.021989917382597923
175 0.021673863753676414
176 0.021362382918596268
177 0.021055368706583977
178 0.02075282670557499
179 0.02045450732111931
180 0.020160524174571037
181 0.019870825111865997
182 0.019585274159908295
183 0.01930379308760166
184 0.019026298075914383
185 0.01875285804271698
186 0.018483413383364677
187 0.0182177871465683
188 0.01795593649148941
189 0.01769784465432167
190 0.017443565651774406
191 0.01719282567501068
192 0.0169457346200943
193 0.016702204942703247
194 0.016462240368127823
195 0.016225600615143776
196 0.015992436558008194
197 0.015762610360980034
198 0.015536056831479073
199 0.01531275361776352
200 0.01509274821728468
201 0.014875786378979683
202 0.014661986380815506
203 0.014451313763856888
204 0.014243602752685547
205 0.01403888314962387
206 0.013837129808962345
207 0.013638310134410858
208 0.013442250899970531
209 0.01324906200170517
210 0.013058701530098915
211 0.012870989739894867
212 0.012686015106737614
213 0.01250374224036932
214 0.012324041686952114
215 0.012146896682679653
216 0.011972364969551563
217 0.011800271458923817
218 0.011630683206021786
219 0.011463530361652374
220 0.011298784986138344
221 0.011136380955576897
222 0.010976347140967846
223 0.01081860065460205
224 0.010663095861673355
225 0.010509897023439407
226 0.010358813218772411
227 0.010210014879703522
228 0.010063203051686287
229 0.009918630123138428
230 0.009776086546480656
231 0.00963554996997118
232 0.009497096762061119
233 0.00936061330139637
234 0.009226078167557716
235 0.009093508124351501
236 0.008962801657617092
237 0.008833975531160831
238 0.008707026019692421
239 0.008581867441534996
240 0.008458593860268593
241 0.008336987346410751
242 0.00821713637560606
243 0.008099078200757504
244 0.007982665672898293
245 0.007867971435189247
246 0.007754916325211525
247 0.007643451914191246
248 0.007533577270805836
249 0.007425309158861637
250 0.007318615913391113
251 0.007213430479168892
252 0.007109746336936951
253 0.007007563021034002
254 0.0069068484008312225
255 0.006807621102780104
256 0.006709778681397438
257 0.006613357923924923
258 0.006518302019685507
259 0.006424624007195234
260 0.006332285236567259
261 0.006241272669285536
262 0.0061515928246080875
263 0.006063209380954504
264 0.0059760636650025845
265 0.005890156142413616
266 0.005805525928735733
267 0.005722035653889179
268 0.005639835726469755
269 0.0055587803944945335
270 0.0054789199493825436
271 0.005400161724537611
272 0.005322564393281937
273 0.0052460250444710255
274 0.00517068337649107
275 0.005096335429698229
276 0.005023115314543247
277 0.004950919654220343
278 0.004879762418568134
279 0.004809624515473843
280 0.004740518983453512
281 0.004672378301620483
282 0.004605223890393972
283 0.004539061337709427
284 0.004473820794373751
285 0.004409499932080507
286 0.00434613972902298
287 0.004283687099814415
288 0.004222120624035597
289 0.004161442629992962
290 0.004101647064089775
291 0.004042675718665123
292 0.003984586801379919
293 0.0039273155853152275
294 0.00387090677395463
295 0.003815251402556896
296 0.003760433988645673
297 0.0037063637282699347
298 0.0036531207151710987
299 0.0036005976144224405
300 0.003548881271854043
301 0.0034978655166924
302 0.0034476048313081264
303 0.0033980298321694136
304 0.0033492010552436113
305 0.0033010661136358976
306 0.003253634786233306
307 0.003206862136721611
308 0.0031607949640601873
309 0.0031153755262494087
310 0.0030705814715474844
311 0.0030264602974057198
312 0.0029829726554453373
313 0.002940083621069789
314 0.0028978462796658278
315 0.0028562003280967474
316 0.002815158339217305
317 0.002774682827293873
318 0.0027347952127456665
319 0.002695516450330615
320 0.0026567645836621523
321 0.0026185819879174232
322 0.0025809709914028645
323 0.0025438638404011726
324 0.0025072875432670116
325 0.0024712560698390007
326 0.0024357298389077187
327 0.002400750992819667
328 0.0023662373423576355
329 0.002332223579287529
330 0.0022987292613834143
331 0.002265690127387643
332 0.0022331103682518005
333 0.0022010307293385267
334 0.0021693864837288857
335 0.002138205338269472
336 0.002107476582750678
337 0.002077211393043399
338 0.0020473389886319637
339 0.00201791781000793
340 0.0019889171235263348
341 0.001960339955985546
342 0.0019321611616760492
343 0.0019044021610170603
344 0.001877025468274951
345 0.0018500544829294086
346 0.0018234821036458015
347 0.0017972624627873302
348 0.0017714430578052998
349 0.0017459901282563806
350 0.0017208807403221726
351 0.0016961493529379368
352 0.0016717755934223533
353 0.0016477338504046202
354 0.0016240653349086642
355 0.0016007355879992247
356 0.0015777155058458447
357 0.001555052469484508
358 0.0015327057335525751
359 0.001510667847469449
360 0.0014889667509123683
361 0.0014675650745630264
362 0.0014464533887803555
363 0.001425670343451202
364 0.001405192306265235
365 0.0013850030954927206
366 0.0013650888577103615
367 0.0013454807922244072
368 0.0013261355925351381
369 0.0013070752611383796
370 0.0012883051531389356
371 0.001269784988835454
372 0.0012515324633568525
373 0.0012335367500782013
374 0.0012158092577010393
375 0.0011983568547293544
376 0.0011811066651716828
377 0.0011641462333500385
378 0.0011474225902929902
379 0.0011309236288070679
380 0.0011146781034767628
381 0.001098650274798274
382 0.001082862843759358
383 0.0010672922944650054
384 0.0010519740171730518
385 0.0010368265211582184
386 0.0010219486430287361
387 0.0010072538862004876
388 0.0009927719365805387
389 0.0009785002330318093
390 0.0009644600213505328
391 0.0009505984489805996
392 0.0009369404288008809
393 0.0009234536555595696
394 0.0009101927280426025
395 0.0008971025235950947
396 0.0008842144161462784
397 0.0008715190342627466
398 0.0008589837234467268
399 0.0008466390427201986
400 0.0008344692178070545
401 0.0008224759949371219
402 0.0008106539607979357
403 0.0007990046287886798
404 0.000787514669355005
405 0.0007762035238556564
406 0.0007650482002645731
407 0.0007540533551946282
408 0.0007432110724039376
409 0.0007325401529669762
410 0.0007220103871077299
411 0.0007116313790902495
412 0.0007013983558863401
413 0.0006913295364938676
414 0.0006813958170823753
415 0.0006715868366882205
416 0.0006619489868171513
417 0.0006524224299937487
418 0.000643059378489852
419 0.0006338122766464949
420 0.0006247149431146681
421 0.00061571947298944
422 0.0006068808725103736
423 0.0005981545546092093
424 0.0005895563517697155
425 0.0005810848670080304
426 0.0005727456882596016
427 0.000564498535823077
428 0.000556396204046905
429 0.0005483878776431084
430 0.0005405200645327568
431 0.0005327520193532109
432 0.0005250974791124463
433 0.0005175354308448732
434 0.0005101095885038376
435 0.0005027778679504991
436 0.0004955431795679033
437 0.0004884239169768989
438 0.00048141099978238344
[深度学习][pytorch]pytorch实现一个简单得线性回归模型并训练

《动手学深度学习》线性回归(PyTorch版)

《动手学深度学习》线性回归的简洁实现(linear-regression-pytorch)

《动手学深度学习》线性回归的简洁实现(linear-regression-pytorch)

翻译: 3.3. 线性回归的简明实现 pytorch

翻译: 3.2. 从零开始实现线性回归 深入神经网络 pytorch