Pytorch使用Pytorch简单实现一个线性模型
Posted 海轰Pro
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch使用Pytorch简单实现一个线性模型相关的知识,希望对你有一定的参考价值。
目录
简介
Hello!
非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~
ଘ(੭ˊᵕˋ)੭
昵称:海轰
标签:程序猿|C++选手|学生
简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金,有幸在竞赛中拿过一些国奖、省奖…已保研
学习经验:扎实基础 + 多做笔记 + 多敲代码 + 多思考 + 学好英语!
唯有努力💪
本文仅记录自己感兴趣的内容
前言
通过一下小例子,梳理一下Pytorch中模型的搭建、训练、测试等…
加深理解
例子1:线性模型
模型 y = 2 ∗ x y = 2 * x y=2∗x
设置的时候 就设置为 y = w ∗ x y = w * x y=w∗x
简单一点 便与理解
import torch
# 自定义一个Pytorch线性模型的类
class LinerModel(torch.nn.Module):
def __init__(self):
# super 父类,调用父类的构造,这一步必须有
# 第一个参数为定义类的名称,第二个为self
super(LinerModel, self).__init__()
'''构造一个对象,包含了权重与偏置Tensor
Linear是属于Module的,因此可以自动实现前馈和反馈的计算
nn:neural network的简写'''
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
# 实际上是类的重写,该函数实际上已经在里面被写入了,这里需要重写一下
# 当类中有"__call__"时才可调用
y_pred = self.linear(x) # 实现一个可调用的对象
return y_pred
def train(times, lr):
# 实例化类 ,该类为callable
model = LinerModel() # 是一个含有"__call__"的类,因此可以直接调用
# model(x) # x 就会送入至forward的函数里,然后对x进行计算
# 需要的数据是"\\hat y" 与"y"
criterion = torch.nn.MSELoss(size_average=False)
# 构建优化器
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
aa = []
for epoch in range(times):
# 前馈计算出 \\hat y,即y_pred
y_pred = model(x_data)
# 前馈计算出 loss
loss = criterion(y_pred, y_data)
aa.append(loss.data.item())
# 查看损失,loss输出的时候回自动变为标量
# print(epoch,loss.data.item())
# 所有的梯度归零
optimizer.zero_grad()
# 反馈,方向传播
loss.backward()
# 更新,根据所有参数和学习率来更新
optimizer.step()
# 输出权重(weight)和偏置(bias)
print(f"w=model.linear.weight.item()")
print(f"b=model.linear.bias.item()")
# 测试模型(Test Model)
x_test = torch.Tensor([[4]])
y_test = model(x_test)
print(f"y_pred=y_test.data")
return aa
if __name__ == '__main__':
name = ["SGD"]
# Pytorch中的数据等均为Tensor变量,即矩阵
x_data = torch.Tensor([[1],[2],[3]]) # 相当于一个numpy.array
y_data = torch.Tensor([[2],[4],[6]])
# 训练200次 学习率:0.01
print(train(200,0.01))
print("执行结束")
参考:https://blog.csdn.net/qq_44285092/article/details/108047578
Note
optimizer.zero_grad()的作用:
参考:https://blog.csdn.net/bigbigvegetable/article/details/114674793
参考资料
- https://blog.csdn.net/qq_44285092/article/details/108047578
- https://blog.csdn.net/bigbigvegetable/article/details/114674793
结语
文章仅作为个人学习笔记记录,记录从0到1的一个过程
希望对您有一点点帮助,如有错误欢迎小伙伴指正
以上是关于Pytorch使用Pytorch简单实现一个线性模型的主要内容,如果未能解决你的问题,请参考以下文章