干货|10分钟入门PyTorch~附源码

Posted 机器学习算法与自然语言处理

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了干货|10分钟入门PyTorch~附源码相关的知识,希望对你有一定的参考价值。


10分钟入门PyTorch(2)


上一节介绍了简单的线性回归,如何在pytorch里面用最小二乘来拟合一些离散的点,这一节我们将开始简单的logistic回归,介绍图像分类问题,使用的数据是手写字体数据集MNIST。

1
  logistic回归


logistic回归简单来说和线性回归是一样的,要做的运算同样是 y = w * x + b。

logistic回归简单的是做二分类问题,使用sigmoid函数将所有的正数和负数都变成0-1之间的数,这样就可以用这个数来确定到底属于哪一类,可以简单的认为概率大于0.5即为第二类,小于0.5为第一类。


这就是sigmoid的图形

干货|10分钟入门PyTorch(2)~附源码


而我们这里要做的是多分类问题,对于每一个数据,我们输出的维数是分类的总数,比如10分类,我们输出的就是一个10维的向量,然后我们使用另外一个激活函数,softmax

干货|10分钟入门PyTorch(2)~附源码

这就是softmax函数作用的机制,其实简单的理解就是确定这10个数每个数对应的概率有多大,因为这10个数有正有负,所以通过指数函数将他们全部变成正数,然后求和,然后这10个数每个数都除以这个和,这样就得到了每个类别的概率。


data


首先导入torch里面专门做图形处理的一个库,torchvision,根据官方安装指南,你在安装pytorch的时候torchvision也会安装。

我们需要使用的是torchvision.transforms和torchvision.datasets以及torch.utils.data.DataLoader


首先DataLoader是导入图片的操作,里面有一些参数,比如batch_size和shuffle等,默认load进去的图片类型是PIL.Image.open的类型,如果你不知道PIL,简单来说就是一种读取图片的库


torchvision.transforms里面的操作是对导入的图片做处理,比如可以随机取(50, 50)这样的窗框大小,或者随机翻转,或者去中间的(50, 50)的窗框大小部分等等,但是里面必须要用的是transforms.ToTensor(),这可以将PIL的图片类型转换成tensor,这样pytorch才可以对其做处理


torchvision.datasets里面有很多数据类型,里面有官网处理好的数据,比如我们要使用的MNIST数据集,可以通过torchvision.datasets.MNIST()来得到,还有一个常使用的是torchvision.datasets.ImageFolder(),这个可以让我们按文件夹来取图片,和keras里面的flow_from_directory()类似,具体的可以去看看官方文档的介绍。


干货|10分钟入门PyTorch(2)~附源码

以上就是我们对图片数据的读取操作

model


之前讲过模型定义的框架,废话不多说,直接上代码

干货|10分钟入门PyTorch(2)~附源码

我们需要向这个模型传入参数,第一个参数定义为数据的维度第二维数是我们分类的数目。


接着我们可以在gpu上跑模型,怎么做呢?
首先可以判断一下你是否能在gpu上跑


干货|10分钟入门PyTorch(2)~附源码

如果返回True就说明有gpu支持

接着你只需要一个简单的命令就可以了


干货|10分钟入门PyTorch(2)~附源码


或者


干货|10分钟入门PyTorch(2)~附源码


都可以

然后需要定义loss和optimizer


干货|10分钟入门PyTorch(2)~附源码


这里我们使用的loss是交叉熵,是一种处理分类问题的loss,optimizer我们还是使用随机梯度下降


train

接着就可以开始训练了


干货|10分钟入门PyTorch(2)~附源码

干货|10分钟入门PyTorch(2)~附源码


注意我们如果将模型放到了gpu上,相应的我们的Variable也要放到gpu上,也很简单


干货|10分钟入门PyTorch(2)~附源码


然后可以测试模型,过程与训练类似,只是注意要将模型改成测试模式


干货|10分钟入门PyTorch(2)~附源码


这是跑完100 epoch的结果



具体的结果多久打印一次,如何打印可以自己在for循环里面去设计


这一部分我们就讲解了如何用logistic回归去做一个简单的图片分类问题,知道了如何在gpu上跑模型,下一节我们将介绍如何写简单的卷积神经网络,不了解卷积网络的同学可以先去我的专栏看看之前卷积网络的介绍。


本文代码已经上传到了github

欢迎查看我的知乎专栏,深度炼丹

欢迎访问我的博客


推荐阅读文章:




全是通俗易懂的硬货!只需置顶~欢迎关注交流~


以上是关于干货|10分钟入门PyTorch~附源码的主要内容,如果未能解决你的问题,请参考以下文章

深度学习之30分钟快速入门PyTorch(附学习资源推荐)

速学10分钟让你了解PyTorch框架(附代码)

10分钟快速入门PyTorch

书籍深度学习框架:PyTorch入门与实践(附代码)

教你用PyTorch实现“看图说话”(附代码学习资源)

PyTorch从入门到精通100讲-Pytorch Geometric 从原理到实战应用案例(附代码)