[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解
Posted 文火冰糖的硅基工坊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解相关的知识,希望对你有一定的参考价值。
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970
目录
第1章 TorchVision概述
1.1 TorchVision
Pytorch非常有用的工具集:
- torchtext:处理自然语言
- torchaudio:处理音频的
- torchvision:处理图像视频的。
torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。
1.2 TorchVision的安装
pip install torchvision
1.3 TorchVision官网的数据集
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
1.4 TorchVision常见的数据集概述
- MNIST
- CIFAR10
- CIFAR100
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- STL10
第2章 CIFAR10数据集
2.1 数据集概述
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。
该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。
其中,50000张用于训练,构成了5个训练批次,每一批10000张图;
其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。
注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
CIFAR-10 的图片样例如图所示,包括
飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
2.2 与 MNIST 数据集比较
与 MNIST 数据集比较, CIFAR-10 具有以下不同点:
- CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
- CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
- 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
- 直接的全连接的线性模型,即使在MNIST表现良好,在 CIFAR-10数据集上表现得很差。
2.3 下载地址
官方下载地址:(很慢)
一共有三个版本:python,matlab,binary version 适用于C语言
http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
第3章 TorchVision对CIFAR10的支持
3.1 函数原型
CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)
- root:存储数据集的根目录
- train=True or false:训练集还是测试集
- transform=None:在加载数据前的格式转换
- target_transform=None:
- download=False:是否需要在线下载
3.2 数据下载前的准备
#环境准备
import numpy as np # numpy数组库
import math # 数学运算库
import matplotlib.pyplot as plt # 画图库
import torch # torch基础库
import torchvision.datasets as dataset #公开数据集的下载和管理
import torchvision.transforms as transforms #公开数据集的预处理库,格式转换
import torchvision.utils as utils
import torch.utils.data as data_utils #对数据集进行分批加载的工具集
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
3.3 数据集下载与导入
如果本地没有数据集,会自动远程下载
#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
train = True,
transform = transforms.ToTensor(),
download = True)
#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
train = False,
transform = transforms.ToTensor(),
download = True)
print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\\cifar-10-python.tar.gz Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\\cifar-10-python.tar.gz
100.0%
Extracting cifar10\\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\\MNIST\\raw\\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\\MNIST\\raw\\train-images-idx3-ubyte.gz to cifar10\\MNIST\\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw\\train-labels-idx1-ubyte.gz Extracting cifar10\\MNIST\\raw\\train-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\\MNIST\\raw\\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to cifar10\\MNIST\\raw Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw Processing...
C:\\ProgramData\\Anaconda3\\envs\\pytorch1.8_py3.8\\lib\\site-packages\\torchvision\\datasets\\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done! Dataset CIFAR10 Number of datapoints: 50000 Root location: cifar10 Split: Train StandardTransform Transform: ToTensor() size= 50000 Dataset MNIST Number of datapoints: 10000 Root location: cifar10 Split: Test StandardTransform Transform: ToTensor() size= 10000
3.4 显示单张样本图片
#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0) #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)
plt.imshow(image)
plt.show()
3.5 启动loader对象
# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
batch_size = 8,
shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
batch_size = 8,
shuffle = True)
print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data), len(test_data)/8)
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA85E0> <torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA8F40> 50000 6250.0 10000 1250.0
3.6 显示批量图片
pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。
因此需要通过transpose()进行维度的变换。
#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])
print("\\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)
print("\\n转换成imshow格式")
images = images.numpy().transpose(1,2,0)
print(images.shape)
print(labels.shape)
print("\\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片 torch.Size([8, 3, 32, 32]) torch.Size([8]) 8 合并成一张三通道灰度图片 torch.Size([3, 70, 138]) torch.Size([8]) 转换成imshow格式 (70, 138, 3) torch.Size([8]) 显示图片
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970
以上是关于[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-36]:数据集 - torchvision与ImageNet数据集
[Pytorch系列-37]: 工具集 - torchvision库详解(数据集数据预处理模型)