[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解

Posted 文火冰糖的硅基工坊

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解相关的知识,希望对你有一定的参考价值。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 CIFAR10数据集

2.1 数据集概述

2.2 与 MNIST 数据集比较

2.3 下载地址

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

3.2 数据下载前的准备

3.3 数据集下载与导入

3.4 显示单张样本图片

3.5 启动loader对象

3.6 显示批量图片


第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • STL10

第2章 CIFAR10数据集

2.1 数据集概述

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。

该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。

其中,50000张用于训练,构成了5个训练批次,每一批10000张图;

其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。

注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

CIFAR-10 的图片样例如图所示,包括

飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。

2.2 与 MNIST 数据集比较

与 MNIST 数据集比较, CIFAR-10 具有以下不同点:

  • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
  • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
  • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
  • 直接的全连接的线性模型,即使在MNIST表现良好,在 CIFAR-10数据集上表现得很差。

2.3 下载地址

官方下载地址:(很慢)

一共有三个版本:python,matlab,binary version 适用于C语言

http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)

  • root:存储数据集的根目录
  • train=True or false:训练集还是测试集
  • transform=None:在加载数据前的格式转换
  • target_transform=None:
  • download=False:是否需要在线下载

3.2 数据下载前的准备

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())

3.3 数据集下载与导入

如果本地没有数据集,会自动远程下载

#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\\cifar-10-python.tar.gz
100.0%
Extracting cifar10\\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\\MNIST\\raw\\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\\MNIST\\raw\\train-images-idx3-ubyte.gz to cifar10\\MNIST\\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw\\train-labels-idx1-ubyte.gz
Extracting cifar10\\MNIST\\raw\\train-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\\MNIST\\raw\\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to cifar10\\MNIST\\raw
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to cifar10\\MNIST\\raw
Processing...
C:\\ProgramData\\Anaconda3\\envs\\pytorch1.8_py3.8\\lib\\site-packages\\torchvision\\datasets\\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done!
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: cifar10
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 50000

Dataset MNIST
    Number of datapoints: 10000
    Root location: cifar10
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

3.4 显示单张样本图片

#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0)  #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)

plt.imshow(image)
plt.show()

3.5 启动loader对象

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 8,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 8,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data),  len(test_data)/8)
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA85E0>
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA8F40>
50000 6250.0
10000 1250.0

3.6 显示批量图片

pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。

因此需要通过transpose()进行维度的变换。

#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)

print("\\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([8, 3, 32, 32])
torch.Size([8])
8

合并成一张三通道灰度图片
torch.Size([3, 70, 138])
torch.Size([8])

转换成imshow格式
(70, 138, 3)
torch.Size([8])

显示图片


作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970

以上是关于[Pytorch系列-33]:数据集 - torchvision与CIFAR10详解的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch: 张量的拼接切分索引

[Pytorch系列-36]:数据集 - torchvision与ImageNet数据集

[Pytorch系列-37]: 工具集 - torchvision库详解(数据集数据预处理模型)

如何创建图神经网络数据集? (pytorch 几何)

[Pytorch系列-45]:卷积神经网络 - 用GPU训练AlexNet+CIFAR10数据集

[Pytorch系列-46]:卷积神经网络 - 用GPU训练ResNet+CIFAR100数据集