Python提取 MNIST 数据集中的图片到本地

Posted Xavier Jiezou

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Python提取 MNIST 数据集中的图片到本地相关的知识,希望对你有一定的参考价值。

引言 | Introduction

MNIST 数据集是最经典的一个机器学习的数据集,常被视为图像分类问题的入门级数据。虽然 Python 的很多第三方包都对其进行了封装,但对于模型训练来说,我们常用的还是本地的数据。今天教大家如何提取 MNIST 数据到本地。

安装 | Install

pip install torchvision==0.11.2 tqdm==4.54.1

方法 | Method

我们利用 torchvision 包封装的 MNIST 数据集来提取图片到本地。MNIST 数据集是一个典型的多分类数据集,其中存放的是 7 万张手写数字的灰度图片(6 万训练和1 万测试),每张灰度图片的大小是 28×28。共有 10 类标签,分别对应数字 0-9。

一般来说,我们回将整个数据集按照 8:1:1 的比例(或其他比例)划分为 3 个子集:训练集,验证集和测试集。MNIST 官方只划分了两个子集,笔者自认为不太合理,故提取到本地时没有单独的创建 train 和 test 子文件夹来存放图片,不过在图片名称前加了 train 和 test 字样,以标识该图片是从哪个子数据集中获取的。

下方给出了具体的实现代码。您需要安装两个第三方 Python 包:torchvision 和 tqdm,然后给定数据在本地保存的文件夹路径,运行代码即可。

代码 | Code

import os
import shutil
from tqdm import tqdm
from torchvision import datasets
from concurrent.futures import ThreadPoolExecutor


def mnist_export(root: str = './data/minst'):
    """Export MNIST data to a local folder using multi-threading.

    Args:
        root (str, optional): Path to local folder. Defaults to './data/minst'.
    """
    for i in range(10):
        os.makedirs(os.path.join(root, f'./i'), exist_ok=True)
    split_list = ['train', 'test']
    data = 
        split: datasets.MNIST(
            root='./tmp',
            train=split == 'train',
            download=True
        ) for split in split_list
    
    total = sum([len(data[split]) for split in split_list])
    with tqdm(total=total) as pbar:
        with ThreadPoolExecutor() as tp:
            for split in split_list:
                for index, (image, label) in enumerate(data[split]):
                    tmp = os.path.join(root, f'label/split_index.png')
                    tp.submit(image.save, tmp).add_done_callback(
                        lambda func: pbar.update()
                    )
    shutil.rmtree('./tmp')


if __name__ == '__main__':
    mnist_export('./data/minst')

参考 | References

下载 | Download

https://cdn.jsdelivr.net/gh/XavierJiezou/pytorch-lstm-examples@main/data/mnist.7z

以上是关于Python提取 MNIST 数据集中的图片到本地的主要内容,如果未能解决你的问题,请参考以下文章

如何显示mnist中的数据(tensroflow)

全连接网络基础1——MNIST数据集

基于MNIST数据集实现手写数字识别

MNIST机器学习入门

python怎么在一群点集中,提取中心坐标

Python读取MNIST数据集