[Pytorch系列-33]:数据集 - torchvision与MNIST数据集
Posted 文火冰糖的硅基工坊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-33]:数据集 - torchvision与MNIST数据集相关的知识,希望对你有一定的参考价值。
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489
目录
第1章 TorchVision概述
1.1 TorchVision
Pytorch非常有用的工具集:
- torchtext:处理自然语言
- torchaudio:处理音频的
- torchvision:处理图像视频的。
torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。
1.2 TorchVision的安装
pip install torchvision
1.3 TorchVision官网的数据集
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
1.4 TorchVision常见的数据集概述
- MNIST
- CIFAR10
- CIFAR100
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- STL10
第2章 MNIST数据集
2.1 MNIST数据集介绍
MNIST数据集: http://yann.lecun.com/exdb/
备注 :可以先把样本数据下载本地,以提升程序调试的效率。最终的产品可以远程下载数据。
- 每张图片大小:28*28.
- 单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)
2.2 样本数据与样本标签格式
2.3 MNIST数据的下载与导入
(1)操作函数MNIST()的解读
MNIST (root, train=True, transform=None, target_transform=None, download=False)
参数说明:
- root : 文件存放路的根路径,下载的文件存放在该路径下,processed/training.pt 和 processed/test.pt 的主目录
- train : True = 训练集, False = 测试集
- target_transform:导入数据时,是否需要对数据格式进行转换,一个函数,原始图片作为输入,返回一个转换后的图片。有时候神经网络所需要的尺寸与数据集提供的尺寸不一致,则可以通过此方法进行转换。
- download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。
(2)代码实例
#环境准备
import numpy as np # numpy数组库
import math # 数学运算库
import matplotlib.pyplot as plt # 画图库
import torch # torch基础库
import torchvision.datasets as dataset #公开数据集的下载和管理
import torchvision.transforms as transforms #公开数据集的预处理库,格式转换
import torchvision.utils as utils
import torch.utils.data as data_utils #对数据集进行分批加载的工具集
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
Hello World 1.8.0 False
#2-1 准备数据集
train_data = dataset.MNIST(root = "mnist",
train = True,
transform = transforms.ToTensor(),
download = True)
#2-1 准备数据集
test_data = dataset.MNIST(root = "mnist",
train = False,
transform = transforms.ToTensor(),
download = True)
print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Dataset MNIST Number of datapoints: 60000 Root location: mnist Split: Train StandardTransform Transform: ToTensor() size= 60000 Dataset MNIST Number of datapoints: 10000 Root location: mnist Split: Test StandardTransform Transform: ToTensor() size= 10000
2.4 对样本数据预处理
(1)原图不叠加噪声显示
#原图不叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)
print("\\n不叠加噪声, 原图显示")
plt.imshow(image)
plt.show()
原始图片 torch image shape: torch.Size([1, 28, 28]) torch image label: 5 单通道原始图片:numpy numpy image shape: (28, 28, 1) numpy image label: 5 不叠加噪声, 原图显示
(2)原图叠加噪声
#原图叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)
print("\\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean
plt.imshow(image)
plt.show()
原始图片 torch image shape: torch.Size([1, 28, 28]) torch image label: 5 单通道原始图片:numpy numpy image shape: (28, 28, 1) numpy image label: 5 叠加噪声, 平滑显示
(3)#叠加噪声,灰度显示图片
#叠加噪声,灰度显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)
print("\\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean
plt.imshow(image)
plt.show()
原始图片 torch image shape: torch.Size([1, 28, 28]) torch image label: 5 三通道灰度图片:torch torch image shape: torch.Size([3, 28, 28]) torch image label: 5 三通道灰度图片:numpy numpy image shape: (28, 28, 3) numpy image label: 5 叠加噪声, 平滑显示
(4)#不叠加噪声,黑白显示图片
#不叠加噪声,黑白显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)
print("\\n不叠加噪声,黑白显示")
plt.imshow(image)
plt.show()
print("numpy image shape:", image.shape)
原始图片 torch image shape: torch.Size([1, 28, 28]) torch image label: 5 三通道灰度图片:torch torch image shape: torch.Size([3, 28, 28]) torch image label: 5 三通道灰度图片:numpy numpy image shape: (28, 28, 3) numpy image label: 5 不叠加噪声,黑白显示
2.5 批量数据读取与显示
(1)batch批量图片的读取
# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
batch_size = 64,
shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
batch_size = 64,
shuffle = True)
print(train_loader)
print(test_loader)
print(len(train_loader), len(train_data)/64)
print(len(test_loader), len(test_data)/64)
<torch.utils.data.dataloader.DataLoader object at 0x000002461EF4A1C0> <torch.utils.data.dataloader.DataLoader object at 0x000002461ED66610> 938 937.5 157 156.25
(2)一个batch图片的显示
显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])
print("\\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs)
print(images.shape)
print(labels.shape)
print("\\n转换成imshow格式")
images = images.numpy().transpose(1,2,0)
print(images.shape)
print(labels.shape)
print("\\n显示样本标签")
#打印图片标签
for i in range(64):
print(labels[i], end=" ")
i += 1
#换行
if i%8 == 0:
print(end='\\n')
print("\\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片 torch.Size([64, 1, 28, 28]) torch.Size([64]) 64 合并成一张三通道灰度图片 torch.Size([3, 242, 242]) torch.Size([64]) 转换成imshow格式 (242, 242, 3) torch.Size([64]) 显示样本标签 tensor(0) tensor(8) tensor(3) tensor(7) tensor(5) tensor(7) tensor(9) tensor(7) tensor(1) tensor(1) tensor(1) tensor(8) tensor(8) tensor(6) tensor(0) tensor(1) tensor(4) tensor(8) tensor(1) tensor(3) tensor(3) tensor(6) tensor(4) tensor(4) tensor(0) tensor(5) tensor(8) tensor(5) tensor(9) tensor(3) tensor(7) tensor(5) tensor(2) tensor(1) tensor(0) tensor(6) tensor(8) tensor(8) tensor(9) tensor(6) tensor(1) tensor(3) tensor(5) tensor(3) tensor(4) tensor(4) tensor(3) tensor(1) tensor(4) tensor(1) tensor(4) tensor(4) tensor(9) tensor(8) tensor(7) tensor(2) tensor(3) tensor(1) tensor(2) tensor(0) tensor(8) tensor(1) tensor(1) tensor(4) 显示图片
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489
以上是关于[Pytorch系列-33]:数据集 - torchvision与MNIST数据集的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-36]:数据集 - torchvision与ImageNet数据集
[Pytorch系列-37]: 工具集 - torchvision库详解(数据集数据预处理模型)