Python提取 MNIST 数据集中的图片到本地
Posted Xavier Jiezou
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Python提取 MNIST 数据集中的图片到本地相关的知识,希望对你有一定的参考价值。
引言 | Introduction
MNIST 数据集是最经典的一个机器学习的数据集,常被视为图像分类问题的入门级数据。虽然 Python 的很多第三方包都对其进行了封装,但对于模型训练来说,我们常用的还是本地的数据。今天教大家如何提取 MNIST 数据到本地。
安装 | Install
pip install torchvision==0.11.2 tqdm==4.54.1
方法 | Method
我们利用 torchvision
包封装的 MNIST 数据集来提取图片到本地。MNIST 数据集是一个典型的多分类数据集,其中存放的是 7 万张手写数字的灰度图片(6 万训练和1 万测试),每张灰度图片的大小是 28×28。共有 10 类标签,分别对应数字 0-9。
一般来说,我们回将整个数据集按照 8:1:1 的比例(或其他比例)划分为 3 个子集:训练集,验证集和测试集。MNIST 官方只划分了两个子集,笔者自认为不太合理,故提取到本地时没有单独的创建 train 和 test 子文件夹来存放图片,不过在图片名称前加了 train 和 test 字样,以标识该图片是从哪个子数据集中获取的。
下方给出了具体的实现代码。您需要安装两个第三方 Python 包:torchvision 和 tqdm,然后给定数据在本地保存的文件夹路径,运行代码即可。
代码 | Code
import os
import shutil
from tqdm import tqdm
from torchvision import datasets
from concurrent.futures import ThreadPoolExecutor
def mnist_export(root: str = './data/minst'):
"""Export MNIST data to a local folder using multi-threading.
Args:
root (str, optional): Path to local folder. Defaults to './data/minst'.
"""
for i in range(10):
os.makedirs(os.path.join(root, f'./i'), exist_ok=True)
split_list = ['train', 'test']
data =
split: datasets.MNIST(
root='./tmp',
train=split == 'train',
download=True
) for split in split_list
total = sum([len(data[split]) for split in split_list])
with tqdm(total=total) as pbar:
with ThreadPoolExecutor() as tp:
for split in split_list:
for index, (image, label) in enumerate(data[split]):
tmp = os.path.join(root, f'label/split_index.png')
tp.submit(image.save, tmp).add_done_callback(
lambda func: pbar.update()
)
shutil.rmtree('./tmp')
if __name__ == '__main__':
mnist_export('./data/minst')
参考 | References
下载 | Download
https://cdn.jsdelivr.net/gh/XavierJiezou/pytorch-lstm-examples@main/data/mnist.7z
以上是关于Python提取 MNIST 数据集中的图片到本地的主要内容,如果未能解决你的问题,请参考以下文章