Pytorch 基本使用(数据加载,类型转换)

Posted 'or 1 or 不正经の泡泡

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 基本使用(数据加载,类型转换)相关的知识,希望对你有一定的参考价值。

文章目录


本博文优先在掘金社区发布!

前言

通过前面的一些介绍的话,我们大概知道了我们的pytorch的tensor的一些基本概念,还有咱们梯度和tensor复制时的一些细节,tensor和numpy在很大程度上很像,在某些场合我们甚至可以直接使用tensor来进行运算。那么现在我们来说说pytorch的一些基本使用。

毕竟我们使用pytorch是用来搭建我们的神经网络,进行深度学习的。那么在机器学习小概述里面说过,深度学习其实也是我们机器学习的一种分支,也就是特殊一点的机器学习。那么先前sklearn的机器学习和aruze的云平台的机器学习步骤大致分为了五部曲,那么在pytorch里面其实也是类似,只是算法部分换成了比较抽象的神经网络。

所以我们可以把pytorch大致分为这几块


那么在这里主要将的是数据的加载,与转换。

类型转换

一开始我们说了,tensor可以将numpy的数据进行转换,但是有时候我们需要处理的可能是文本,或者图片,声音。所以我们需要一个转换器(当然你也可以转成numpy然后再转换为tensor但是那闲的慌才那么干)

这里使用工具包

tensorvision

例如我们对图片进行转换。


我们发现这工具包下面还有很多内容,Totensor()可以直接进行转换(看见源码还有说明)


在这里我们就轻松完成了转化。

Compose “链式转化”

有时候我们可能需要进行多次转化,例如我们需要对一个图片先改变尺寸,然后进行转化。那么这个时候为了避免重复代码,所以此时我们还能这样做。


from torchvision import transforms

tensor_to = transforms.ToTensor()
compose = transforms.Compose([tensor_to,])
image = Image.open("train/1/0BGHNV6P.jpg")

img = compose(image)
print(img)


那么这里还有其他的方式,我就不说了,你pycharm一点都出来了,还有注释。
类型转换的话其实很简单,而且对应的情况也比较多,我这边真不好说明。

数据处理

我们都知道,机器学习其实都离不开数据,数据集。那么对于一些比较出名的网络模型,或者是数据集,在pytorch里面都提供了自动下载的工具。

自带数据集

这个是指pytorch会自动通过爬虫来下载数据集合,然后给我们封装好。
这个也是使用tenssorvision
例如下载CIFAR10数据集

train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
tese_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

直接搞的,但是注意的是,这里得到的数据集不是tensor类型的,我们还要进行类型转换

from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans,download=True)

数据加载

之后就是我们的数据加载
这里的话使用的就是utils下面的工具了

from torch.utils.data import DataLoader
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans,download=True)

dataloader = DataLoader(dataset,batch_size=64)

这里主要介绍一些DataLoader的参数。

自定义获取数据

这个的话就比较原始,就是有时候我们需要自己加载数据集,举个例子。

这个就是从网上下载的数据集,现在要把这个导入到我们的pytorch里面。
这个文件夹是标签名,在这个数据集里面。1 是 1块钱的图片 100是一百块钱的图片。

我这里先直接给出代码

from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import os
from PIL import Image

# 通过Dataset来获取数据
class MyDataset(Dataset):

    def __init__(self,RootDir,LabelDir):
        self.RootDir = RootDir

        self.LabelDir = LabelDir
        self.transform = transforms.ToTensor()
        self.ImagePathDir = os.path.join(self.RootDir,self.LabelDir)
        self.ImageNameItems = os.listdir(self.ImagePathDir)

    def __getitem__(self, item):
        # item 是获取某一个数据元素,懒汉模式,你要用我才给你
        ItemName = self.ImageNameItems[item]
        ImagePathItem = os.path.join(self.RootDir,self.LabelDir,ItemName)
        ItemGet = self.transform(Image.open(ImagePathItem).resize((500,500)))
        ItemLabel = self.LabelDir
        return ItemGet,ItemLabel

    def __len__(self):
        return len(self.ImageNameItems)




if __name__ =="__main__":
    RootDir = "train"
    OneYuanLabel = "1"
    HandoneYuanLabel = "100"
    OneYuanData = MyDataset(RootDir,OneYuanLabel)
    HandoneData = MyDataset(RootDir,HandoneYuanLabel)

    DataGet  = OneYuanData+HandoneData


    train_data = DataLoader(dataset=DataGet,batch_size=18,shuffle=True,num_workers=0,drop_last=True)

    for data in train_data:
        imgs,tags = data
        print(imgs.shape)



重点是我们进行那个 继承Dataset,然后实现 __getitem()__这个魔法方法。看代码其实很简单,获取了我们路径的图片名称,然后,再调用魔法方法的时候,我们读取图片然后直接转化为tensor,这个其实和前面获取的数据是类似的,只是我们直接转化了一下,同时这里也是为什么我们要用 DataLoader,用这个可以把数据拿出来,而不是等到训练模型的时候再来,那样是很慢的。

总结

这些就是最基本的操作,那么明天就说是那个怎么玩神经网络,使用pytorch。这边我们还是以CNN为例,搭建一下CIFAR10这个模型。后面我们再做一个小demo。

其实关于pytorch的使用是非常简单的,但是有很多先决条件,不然这个很难理解,这个不像那个python的django java的ssm springcloud 之类的这种业务性框架,背几个API几个注解就OK了, 那个东西上手很快,当然源码就另说了。

以上是关于Pytorch 基本使用(数据加载,类型转换)的主要内容,如果未能解决你的问题,请参考以下文章

Java程序员学深度学习 DJL上手7 使用Pytorch引擎

加载器的无效数据类型 - Pytorch Lightning DataModule

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext

PyTorch常用知识总结

PyTorch常用知识总结