线性回归的pytorch代码

Posted dataat

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了线性回归的pytorch代码相关的知识,希望对你有一定的参考价值。

使用pytorch实现的线性回归, 闲言少叙,直接上代码,客官请看:

 1 import torch
 2 import torch.nn as  nn
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 
 6 #设置相关参数
 7 input_size=1
 8 output_size=1
 9 num_epochs=60
10 learning_rate=0.001
11 
12 #导入训练数据集
13 x_train=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],
14                   [9.779],[6.182],[7.59],[2.167],[7.042],
15                   [10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
16 y_train=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],
17                   [3.366],[2.596],[2.53],[1.221],[2.827],
18                   [3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
19 
20 #设置模型、损失函数、优化函数等
21 model=nn.Linear(input_size,output_size)
22 criterion=nn.MSELoss()
23 optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
24 
25 #开始迭代训练
26 for epoch in  range(num_epochs):
27     inputs=torch.from_numpy(x_train)
28     targets=torch.from_numpy(y_train)
29 
30     outputs=model(inputs)
31     loss=criterion(outputs,targets)
32 
33     optimizer.zero_grad()
34     loss.backward()
35     optimizer.step()
36 
37     if (epoch+1) %5==0:
38         print("Epoch [{}/{}], LossL:{:.4f}".format(epoch+1,num_epochs,loss.item()))
39 #计算出训练之后的期望值/预测值,并与实际值进行画图比较
40 predicted=model(torch.from_numpy(x_train)).detach().numpy()
41 plt.plot(x_train,y_train,ro,label=Original Data)
42 plt.plot(x_train,predicted,label=Fitted Line)
43 plt.legend()
44 plt.show()
45 #保存模型相关数据
46 torch.save(model.state_dict(),model.ckpt

 

-------------------- 正文到此结束------------------------

推荐一个公众号:健哥聊量化,会持续推出股票相关基础知识,以及python实现的一些基本的分析代码。欢迎大家关注,二维码如下:

技术图片技术图片?

相关文章列表如下:

?

以上是关于线性回归的pytorch代码的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch实现线性回归(API完成)

Pytorch实现线性回归

PyTorch学习笔记 8. 实现线性回归模型

PyTorch学习笔记 8. 实现线性回归模型

PyTorch 完全入门指南!从线性回归逻辑回归到图像分类

从零开始学PyTorch:一文学会线性回归逻辑回归及图像分类