Pytorch 基本使用(数据加载,类型转换)
Posted 'or 1 or 不正经の泡泡
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 基本使用(数据加载,类型转换)相关的知识,希望对你有一定的参考价值。
文章目录
本博文优先在掘金社区发布!
前言
通过前面的一些介绍的话,我们大概知道了我们的pytorch的tensor的一些基本概念,还有咱们梯度和tensor复制时的一些细节,tensor和numpy在很大程度上很像,在某些场合我们甚至可以直接使用tensor来进行运算。那么现在我们来说说pytorch的一些基本使用。
毕竟我们使用pytorch是用来搭建我们的神经网络,进行深度学习的。那么在机器学习小概述里面说过,深度学习其实也是我们机器学习的一种分支,也就是特殊一点的机器学习。那么先前sklearn的机器学习和aruze的云平台的机器学习步骤大致分为了五部曲,那么在pytorch里面其实也是类似,只是算法部分换成了比较抽象的神经网络。
所以我们可以把pytorch大致分为这几块
那么在这里主要将的是数据的加载,与转换。
类型转换
一开始我们说了,tensor可以将numpy的数据进行转换,但是有时候我们需要处理的可能是文本,或者图片,声音。所以我们需要一个转换器(当然你也可以转成numpy然后再转换为tensor但是那闲的慌才那么干)
这里使用工具包
tensorvision
例如我们对图片进行转换。
我们发现这工具包下面还有很多内容,Totensor()可以直接进行转换(看见源码还有说明)
在这里我们就轻松完成了转化。
Compose “链式转化”
有时候我们可能需要进行多次转化,例如我们需要对一个图片先改变尺寸,然后进行转化。那么这个时候为了避免重复代码,所以此时我们还能这样做。
from torchvision import transforms
tensor_to = transforms.ToTensor()
compose = transforms.Compose([tensor_to,])
image = Image.open("train/1/0BGHNV6P.jpg")
img = compose(image)
print(img)
那么这里还有其他的方式,我就不说了,你pycharm一点都出来了,还有注释。
类型转换的话其实很简单,而且对应的情况也比较多,我这边真不好说明。
数据处理
我们都知道,机器学习其实都离不开数据,数据集。那么对于一些比较出名的网络模型,或者是数据集,在pytorch里面都提供了自动下载的工具。
自带数据集
这个是指pytorch会自动通过爬虫来下载数据集合,然后给我们封装好。
这个也是使用tenssorvision
例如下载CIFAR10数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
tese_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
直接搞的,但是注意的是,这里得到的数据集不是tensor类型的,我们还要进行类型转换
from torchvision import transforms
trans = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans,download=True)
数据加载
之后就是我们的数据加载
这里的话使用的就是utils下面的工具了
from torch.utils.data import DataLoader
from torchvision import transforms
trans = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans,download=True)
dataloader = DataLoader(dataset,batch_size=64)
这里主要介绍一些DataLoader的参数。
自定义获取数据
这个的话就比较原始,就是有时候我们需要自己加载数据集,举个例子。
这个就是从网上下载的数据集,现在要把这个导入到我们的pytorch里面。
这个文件夹是标签名,在这个数据集里面。1 是 1块钱的图片 100是一百块钱的图片。
我这里先直接给出代码
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import os
from PIL import Image
# 通过Dataset来获取数据
class MyDataset(Dataset):
def __init__(self,RootDir,LabelDir):
self.RootDir = RootDir
self.LabelDir = LabelDir
self.transform = transforms.ToTensor()
self.ImagePathDir = os.path.join(self.RootDir,self.LabelDir)
self.ImageNameItems = os.listdir(self.ImagePathDir)
def __getitem__(self, item):
# item 是获取某一个数据元素,懒汉模式,你要用我才给你
ItemName = self.ImageNameItems[item]
ImagePathItem = os.path.join(self.RootDir,self.LabelDir,ItemName)
ItemGet = self.transform(Image.open(ImagePathItem).resize((500,500)))
ItemLabel = self.LabelDir
return ItemGet,ItemLabel
def __len__(self):
return len(self.ImageNameItems)
if __name__ =="__main__":
RootDir = "train"
OneYuanLabel = "1"
HandoneYuanLabel = "100"
OneYuanData = MyDataset(RootDir,OneYuanLabel)
HandoneData = MyDataset(RootDir,HandoneYuanLabel)
DataGet = OneYuanData+HandoneData
train_data = DataLoader(dataset=DataGet,batch_size=18,shuffle=True,num_workers=0,drop_last=True)
for data in train_data:
imgs,tags = data
print(imgs.shape)
重点是我们进行那个 继承Dataset,然后实现 __getitem()__这个魔法方法。看代码其实很简单,获取了我们路径的图片名称,然后,再调用魔法方法的时候,我们读取图片然后直接转化为tensor,这个其实和前面获取的数据是类似的,只是我们直接转化了一下,同时这里也是为什么我们要用 DataLoader,用这个可以把数据拿出来,而不是等到训练模型的时候再来,那样是很慢的。
总结
这些就是最基本的操作,那么明天就说是那个怎么玩神经网络,使用pytorch。这边我们还是以CNN为例,搭建一下CIFAR10这个模型。后面我们再做一个小demo。
其实关于pytorch的使用是非常简单的,但是有很多先决条件,不然这个很难理解,这个不像那个python的django java的ssm springcloud 之类的这种业务性框架,背几个API几个注解就OK了, 那个东西上手很快,当然源码就另说了。
以上是关于Pytorch 基本使用(数据加载,类型转换)的主要内容,如果未能解决你的问题,请参考以下文章
Java程序员学深度学习 DJL上手7 使用Pytorch引擎
加载器的无效数据类型 - Pytorch Lightning DataModule
小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext