Fixing the MNIST loading error "RuntimeError: Dataset not found. You can use download=True to download it"

Posted by herosunly


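1. Introduction

  In the official PyTorch tutorials, the first program trains a simple neural network on the Fashion-MNIST dataset and uses it for prediction. The first step of any machine learning / deep learning workflow is reading the data, and the tutorial does it with the following code: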

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

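  Unfortunately, because of network restrictions in China, the download often fails and the following error appears: RuntimeError: Dataset not found. You can use download=True to download it. The solutions offered on many platforms and blogs are incomplete and unfriendly to beginners, because they explain neither the logic nor the cause.
  So how can this be solved?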

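2. Downloading the data

  Use Xunlei (Thunder) or any other download tool to fetch the following four data files. Their URLs are the ones listed in the resources attribute of the FashionMNIST class in the appendix:

  train-images-idx3-ubyte.gz
  train-labels-idx1-ubyte.gz
  t10k-images-idx3-ubyte.gz
  t10k-labels-idx1-ubyte.gz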
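
  A manual download can be truncated or corrupted, so it is worth checking the files before going further. The following is a minimal sketch, not part of torchvision: the helper names file_md5 and check_downloads are made up for illustration, and the checksums are copied from the FashionMNIST resources list in mnist.py.

```python
import hashlib
import os

# MD5 checksums copied from the FashionMNIST `resources` list in mnist.py.
FASHION_MNIST_MD5 = {
    "train-images-idx3-ubyte.gz": "8d4fb7e6c68d591d4c3dfef9ec88bf0d",
    "train-labels-idx1-ubyte.gz": "25c81989df183df01b3e8a0aad5dffbe",
    "t10k-images-idx3-ubyte.gz": "bef4ecab320f06d8554ea6380940ec79",
    "t10k-labels-idx1-ubyte.gz": "bb300cfdad3c16e7a12a480ee83cd310",
}


def file_md5(path, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_downloads(folder, expected=FASHION_MNIST_MD5):
    """Return {filename: status}; status is 'ok', 'missing' or 'corrupt'."""
    report = {}
    for name, md5 in expected.items():
        path = os.path.join(folder, name)
        if not os.path.isfile(path):
            report[name] = "missing"
        elif file_md5(path) != md5:
            report[name] = "corrupt"
        else:
            report[name] = "ok"
    return report
```

  Calling check_downloads on the folder that holds the four .gz files should report "ok" for each of them; any other status means the file needs to be downloaded again.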

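3. Modifying the code

3.1 The logic of the change

  First we need to locate the code to modify. As covered in the modules chapter of the Python course series, every package has a special attribute, __path__, which is simply the directory where the package lives.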

import torchvision
print(torchvision.__path__)
['/home/anaconda3/lib/python3.6/site-packages/torchvision']

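  The full path of mnist.py is therefore the sub-path datasets/mnist.py under the directory printed above. Reading its key code shows that download=True not only downloads the .gz files (images and labels) but also saves them as .pt files in torch format. The files we downloaded by hand are only .gz files, so we need code that converts .gz to .pt.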
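
  The path can also be built in code rather than by hand. This is a small sketch under the assumption that the package is a regular one with a single __path__ entry; the helper name package_file is hypothetical.

```python
import importlib
import os


def package_file(package_name, *relative_parts):
    """Build the path of a file inside an installed package."""
    package = importlib.import_module(package_name)
    # __path__ is a list-like of directories; a regular package has one entry.
    return os.path.join(list(package.__path__)[0], *relative_parts)


# Usage (requires torchvision to be installed):
# print(package_file("torchvision", "datasets", "mnist.py"))
```

  The printed path is the file that the following sections modify, so it is a good idea to back it up first.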

  To leave the existing parameters and processing logic untouched, the constructor gets one new parameter, load_gz_files, and a matching method, load_gz_files(). The method reuses extract_archive() and check_integrity() from utils.py.

def load_gz_files(self):
    """Load the .gz format MNIST data if it exists."""
    if self._check_exists():
        return

    if not os.path.exists(self.processed_folder):
        makedir_exist_ok(self.processed_folder)

    for url, md5 in self.resources:
        filename = url.rpartition('/')[2]
        fpath = os.path.join(self.raw_folder, filename)
        if check_integrity(fpath, md5):
            extract_archive(from_path=fpath, to_path=self.processed_folder, remove_finished=False)

    training_set = (
        read_image_file(os.path.join(self.processed_folder, 'train-images-idx3-ubyte')),
        read_label_file(os.path.join(self.processed_folder, 'train-labels-idx1-ubyte'))
    )
    test_set = (
        read_image_file(os.path.join(self.processed_folder, 't10k-images-idx3-ubyte')),
        read_label_file(os.path.join(self.processed_folder, 't10k-labels-idx1-ubyte'))
    )
    with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
        torch.save(training_set, f)
    with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
        torch.save(test_set, f)

    print('Done!')
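
  The extracted files (train-images-idx3-ubyte and so on) are stored in the IDX binary format: a 4-byte header (two zero bytes, a type code, the number of dimensions), one big-endian 32-bit size per dimension, then the raw data. The sketch below only illustrates that layout with the standard library; it is not torchvision's read_image_file(), and the function name read_idx_gz is an assumption made for this example.

```python
import gzip
import struct


def read_idx_gz(path):
    """Return (shape, raw_bytes) from a gzip-compressed unsigned-byte IDX file."""
    with gzip.open(path, "rb") as f:
        data = f.read()

    # Header: two zero bytes, a type code, and the number of dimensions.
    zero, type_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or type_code != 0x08:  # 0x08 means unsigned byte
        raise ValueError("not an unsigned-byte IDX file: %r" % path)

    # One big-endian 32-bit integer per dimension follows the header.
    header_end = 4 + 4 * ndim
    shape = struct.unpack(">" + "I" * ndim, data[4:header_end])
    payload = data[header_end:]

    # The payload must hold exactly one byte per element.
    expected = 1
    for dim in shape:
        expected *= dim
    if len(payload) != expected:
        raise ValueError("expected %d bytes, found %d" % (expected, len(payload)))
    return shape, payload
```

  For a label file the shape has one entry (the number of labels); for an image file it has three (count, rows, columns).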

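  To keep load_gz_files() and the original download() from interfering with each other, the constructor was changed slightly (the complete code is at the end of the article; to use it, copy it over the contents of mnist.py):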

if load_gz_files:
    self.load_gz_files()
else:
    if download:
        self.download()

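3.2 Using the code

  How is the modified code used?

  1. Create a folder such as /home/data/FashionMNIST/raw and put the four downloaded .gz files in it.
  2. Read the data with the code below (note that root is the root of the new folder tree, here /home/data/):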
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="/home/data/",
    train=True,
    download=False,
    transform=ToTensor(),
    load_gz_files=True
)

test_data = datasets.FashionMNIST(
    root="/home/data/",
    train=False,
    download=False,
    transform=ToTensor(),
    load_gz_files=True
)
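
  A wrong root is the most common reason for still seeing "Dataset not found". The following sketch checks the layout before the dataset is constructed. It assumes the raw folder is root/<class name>/raw, as the raw_folder property in mnist.py computes it; the helper name missing_raw_files is hypothetical.

```python
import os

RAW_FILES = (
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
)


def missing_raw_files(root, dataset_name="FashionMNIST"):
    """List the expected .gz files that are absent from root/<dataset_name>/raw."""
    raw_folder = os.path.join(root, dataset_name, "raw")
    return [
        name for name in RAW_FILES
        if not os.path.isfile(os.path.join(raw_folder, name))
    ]
```

  With the layout from step 1, missing_raw_files("/home/data/") should return an empty list.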

  If the from .utils import statement fails with ImportError: cannot import name 'makedir_exist_ok', just add this small function to utils.py:

# utils.py needs these two imports (skip them if they are already present)
import errno
import os

def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise
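
  On Python 3 only, the same helper can be written more briefly, because os.makedirs() accepts exist_ok directly. This is an alternative sketch, not the version that shipped with torchvision:

```python
import os


def makedir_exist_ok(dirpath):
    """Create dirpath and any missing parents; do nothing if it already exists."""
    os.makedirs(dirpath, exist_ok=True)
```

  Either version works for the modified mnist.py; the try/except form is only needed when Python 2 must be supported.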

3.3 Appendix: complete mnist.py code

  For convenience, the full source is provided below:

from __future__ import print_function
from .vision import VisionDataset
import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
import codecs
from .utils import download_url, download_and_extract_archive, extract_archive, \
    makedir_exist_ok, verify_str_arg, check_integrity


class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and  ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    resources = [
        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, load_gz_files=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if load_gz_files:
            self.load_gz_files()
        else:
            if download:
                self.download()
        
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.test_file)))

    def load_gz_files(self):
        """Load the .gz format MNIST data if it exists."""
        if self._check_exists():
            return

        if not os.path.exists(self.processed_folder):
            makedir_exist_ok(self.processed_folder)
    
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            fpath = os.path.join(self.raw_folder, filename)
            if check_integrity(fpath, md5):
                extract_archive(from_path=fpath, to_path=self.processed_folder, remove_finished=False)

        training_set = (
            read_image_file(os.path.join(self.processed_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.processed_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.processed_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.processed_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")


class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
            and  ``Fashion-MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    resources = [
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
         "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
         "25c81989df183df01b3e8a0aad5dffbe"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
         "bef4ecab320f06d8554ea6380940ec79"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
         "bb300cfdad3c16e7a12a480ee83cd310")
    ]
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/processed/training.pt``
            and  ``KMNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    resources = [
        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
    ]
    classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']


class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
            and  ``EMNIST/processed/test.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Updated URL from https://www.nist.gov/node/1298471/emnist-dataset since the
    # _official_ download link
    # https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download
    # is (currently) unavailable
    url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    @staticmethod
    def _training_file(split):
        return 'training_{}.pt'.format(split)

    @staticmethod
    def _test_file(split):
        return 'test_{}.pt'.format(split)

    def download(self):