[基于Pytorch的MNIST识别02]用户数据集的读取

Posted AIplusX

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[基于Pytorch的MNIST识别02]用户数据集的读取相关的知识,希望对你有一定的参考价值。

写在前面

pytorch包含了很多包括mnist在内的开源数据集,但是如果要建立自己的神经网络的话肯定需要训练自己的数据集,那么如何利用pytorch加载用户自己的数据集呢?今天就来解决这个问题。

今天的工作

需要加载用户自己的数据需要继承pytorch的Dataset类,并且重载其中的__getitem__方法和__len__方法,前者是告诉pytorch如何根据索引来获取对应的样例,后者是告诉pytorch数据集的大小,在这里先放一下我的继承实现:

class UserMNIST(Dataset):
    def __init__(self, root='', train=True, transform=None, target_transform=None):
        super(UserMNIST, self).__init__()
        self.train = train #type of datasets
        self.transform = transform
        self.target_transform = target_transform
        self.train_nums = int(6e4)
        self.test_ratio = int(9e-1)
        self.validate_nums = int(1e4)
#         print(self.train_nums)#The scientific counting method is float,which should be changed by user
        
        #load files path
        if self.train :
            self.imgs_folder_path = train_imgs_path
            self.labels_folder_path = train_labels_path
            self.img_nums = self.train_nums
        else:
            self.imgs_folder_path = validate_imgs_path
            self.labels_folder_path = validate_labels_path
            self.img_nums = self.validate_nums
        
        #load dataset
        with open(self.imgs_folder_path, 'rb') as _imgs:
            self._train_images = _imgs.read()
        with open(self.labels_folder_path, 'rb') as _labs:
            self._train_labels = _labs.read()
            
            
    def __getitem__(self, index):
        image = self.getImages(self._train_images, index)
        label = self.getLabels(self._train_labels, index)
        return image,label
    
    def __len__(self):
        return self.img_nums
    
    def getImages(self, image, index):
        img_size_bit = struct.calcsize('>784B')
        start_index = struct.calcsize('>IIII') + index * img_size_bit
        temp = struct.unpack_from('>784B', image, start_index)
        img = np.reshape(temp, (28, 28))
#         img = self.normalization(np.array(temp, dtype=float))
#         print(img)
        return img

    def getLabels(self, label, index):
        lab_size_bit = struct.calcsize('>1B')
        start_index = struct.calcsize('>II') + index * lab_size_bit
        lab = struct.unpack_from('>1B', label, start_index)
        return lab
    
    def normalization(self, x):
        max = float(255)
        min = float(0)
        for i in range(0, 784):
            x[i] = (x[i] - min) / (max - min)
        return x

usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type

对我这个类进行测试,测试程序如下所示(输出结果附后):

dataiter = iter(usermnist_train_loader)
images,labels = dataiter.next()
print(images.shape, labels)
print(type(images), type(labels))
plt.imshow(images[0])
plt.show()

img, lab = usermnist_train.__getitem__(6) # get the 34th sample
print(type(img))
print(type(lab))
plt.imshow(img)
plt.show()

其实关于这个类的重载最重要的是其使用方法,也就是这2个语句:

usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type

UserMNIST是继承了Dataset类的实现,从测试样例之中就可以看出来正确继承了,其实例化之后的usermnist_train做为参数通过DataLoader进行数据加载。

我的理解是DataLoader类是一种数据集加载优化的方法,因为如果用户有大量的数据需要加载的话会很花时间,DataLoader类可以通过多线程加载进行提速,而且可以打乱数据集的顺序增加随机性,使得训练集环境更接近实际环境等优点。

以上是关于[基于Pytorch的MNIST识别02]用户数据集的读取的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)

[基于Pytorch的MNIST识别03]运行模型

[基于Pytorch的MNIST识别05]总结

基于PyTorch实现MNIST手写字识别