线性回归的pytorch代码
Posted dataat
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了线性回归的pytorch代码相关的知识,希望对你有一定的参考价值。
使用pytorch实现的线性回归, 闲言少叙,直接上代码,客官请看:
1 import torch 2 import torch.nn as nn 3 import numpy as np 4 import matplotlib.pyplot as plt 5 6 #设置相关参数 7 input_size=1 8 output_size=1 9 num_epochs=60 10 learning_rate=0.001 11 12 #导入训练数据集 13 x_train=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168], 14 [9.779],[6.182],[7.59],[2.167],[7.042], 15 [10.791],[5.313],[7.997],[3.1]],dtype=np.float32) 16 y_train=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573], 17 [3.366],[2.596],[2.53],[1.221],[2.827], 18 [3.465],[1.65],[2.904],[1.3]],dtype=np.float32) 19 20 #设置模型、损失函数、优化函数等 21 model=nn.Linear(input_size,output_size) 22 criterion=nn.MSELoss() 23 optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate) 24 25 #开始迭代训练 26 for epoch in range(num_epochs): 27 inputs=torch.from_numpy(x_train) 28 targets=torch.from_numpy(y_train) 29 30 outputs=model(inputs) 31 loss=criterion(outputs,targets) 32 33 optimizer.zero_grad() 34 loss.backward() 35 optimizer.step() 36 37 if (epoch+1) %5==0: 38 print("Epoch [{}/{}], LossL:{:.4f}".format(epoch+1,num_epochs,loss.item())) 39 #计算出训练之后的期望值/预测值,并与实际值进行画图比较 40 predicted=model(torch.from_numpy(x_train)).detach().numpy() 41 plt.plot(x_train,y_train,‘ro‘,label=‘Original Data‘) 42 plt.plot(x_train,predicted,label=‘Fitted Line‘) 43 plt.legend() 44 plt.show() 45 #保存模型相关数据 46 torch.save(model.state_dict(),‘model.ckpt‘
-------------------- 正文到此结束------------------------
推荐一个公众号:健哥聊量化,会持续推出股票相关基础知识,以及python实现的一些基本的分析代码。欢迎大家关注,二维码如下:
相关文章列表如下:
?
以上是关于线性回归的pytorch代码的主要内容,如果未能解决你的问题,请参考以下文章