[Pytorch系列-33]:数据集 - torchvision与MNIST数据集

Posted 文火冰糖的硅基工坊

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-33]:数据集 - torchvision与MNIST数据集相关的知识,希望对你有一定的参考价值。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 MNIST数据集

2.1 MNIST数据集介绍

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

2.4 对样本数据预处理

2.5 批量数据读取与显示


第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • STL10

第2章 MNIST数据集

2.1 MNIST数据集介绍

MNIST数据集http://yann.lecun.com/exdb/

 备注 :可以先把样本数据下载本地,以提升程序调试的效率。最终的产品可以远程下载数据。

  • 每张图片大小:28*28.
  • 单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

(1)操作函数MNIST()的解读

MNIST (root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  • root : 文件存放路的根路径,下载的文件存放在该路径下,processed/training.pt 和 processed/test.pt 的主目录
  • train : True = 训练集, False = 测试集
  • target_transform:导入数据时,是否需要对数据格式进行转换,一个函数,原始图片作为输入,返回一个转换后的图片。有时候神经网络所需要的尺寸与数据集提供的尺寸不一致,则可以通过此方法进行转换。
  • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。

(2)代码实例

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
Hello World
1.8.0
False

#2-1 准备数据集
train_data = dataset.MNIST(root = "mnist",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "mnist",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Dataset MNIST
    Number of datapoints: 60000
    Root location: mnist
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 60000

Dataset MNIST
    Number of datapoints: 10000
    Root location: mnist
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

2.4 对样本数据预处理

(1)原图不叠加噪声显示

#原图不叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\\n不叠加噪声, 原图显示")

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

不叠加噪声, 原图显示

(2)原图叠加噪声

#原图叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

叠加噪声, 平滑显示

 

(3)#叠加噪声,灰度显示图片

#叠加噪声,灰度显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

叠加噪声, 平滑显示

(4)#不叠加噪声,黑白显示图片

#不叠加噪声,黑白显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\\n不叠加噪声,黑白显示")
plt.imshow(image)
plt.show()
print("numpy image shape:", image.shape)
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

不叠加噪声,黑白显示

2.5 批量数据读取与显示

(1)batch批量图片的读取

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 64,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 64,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_loader), len(train_data)/64)
print(len(test_loader),  len(test_data)/64)
<torch.utils.data.dataloader.DataLoader object at 0x000002461EF4A1C0>
<torch.utils.data.dataloader.DataLoader object at 0x000002461ED66610>
938 937.5
157 156.25

(2)一个batch图片的显示

显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs)
print(images.shape)
print(labels.shape)

print("\\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\\n显示样本标签")
#打印图片标签
for i in range(64):
    print(labels[i], end=" ")
    i += 1
    #换行
    if i%8 == 0:
        print(end='\\n')

print("\\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([64, 1, 28, 28])
torch.Size([64])
64

合并成一张三通道灰度图片
torch.Size([3, 242, 242])
torch.Size([64])

转换成imshow格式
(242, 242, 3)
torch.Size([64])

显示样本标签
tensor(0) tensor(8) tensor(3) tensor(7) tensor(5) tensor(7) tensor(9) tensor(7) 
tensor(1) tensor(1) tensor(1) tensor(8) tensor(8) tensor(6) tensor(0) tensor(1) 
tensor(4) tensor(8) tensor(1) tensor(3) tensor(3) tensor(6) tensor(4) tensor(4) 
tensor(0) tensor(5) tensor(8) tensor(5) tensor(9) tensor(3) tensor(7) tensor(5) 
tensor(2) tensor(1) tensor(0) tensor(6) tensor(8) tensor(8) tensor(9) tensor(6) 
tensor(1) tensor(3) tensor(5) tensor(3) tensor(4) tensor(4) tensor(3) tensor(1) 
tensor(4) tensor(1) tensor(4) tensor(4) tensor(9) tensor(8) tensor(7) tensor(2) 
tensor(3) tensor(1) tensor(2) tensor(0) tensor(8) tensor(1) tensor(1) tensor(4) 

显示图片



作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489

以上是关于[Pytorch系列-33]:数据集 - torchvision与MNIST数据集的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch: 张量的拼接切分索引

[Pytorch系列-36]:数据集 - torchvision与ImageNet数据集

[Pytorch系列-37]: 工具集 - torchvision库详解(数据集数据预处理模型)

如何创建图神经网络数据集? (pytorch 几何)

[Pytorch系列-45]:卷积神经网络 - 用GPU训练AlexNet+CIFAR10数据集

[Pytorch系列-46]:卷积神经网络 - 用GPU训练ResNet+CIFAR100数据集