[基于Pytorch的MNIST识别02]用户数据集的读取
Posted AIplusX
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[基于Pytorch的MNIST识别02]用户数据集的读取相关的知识,希望对你有一定的参考价值。
写在前面
pytorch包含了很多包括mnist在内的开源数据集,但是如果要建立自己的神经网络的话肯定需要训练自己的数据集,那么如何利用pytorch加载用户自己的数据集呢?今天就来解决这个问题。
今天的工作
需要加载用户自己的数据需要继承pytorch的Dataset类,并且重载其中的__getitem__方法和__len__方法,前者是告诉pytorch如何根据索引来获取对应的样例,后者是告诉pytorch数据集的大小,在这里先放一下我的继承实现:
class UserMNIST(Dataset):
def __init__(self, root='', train=True, transform=None, target_transform=None):
super(UserMNIST, self).__init__()
self.train = train #type of datasets
self.transform = transform
self.target_transform = target_transform
self.train_nums = int(6e4)
self.test_ratio = int(9e-1)
self.validate_nums = int(1e4)
# print(self.train_nums)#The scientific counting method is float,which should be changed by user
#load files path
if self.train :
self.imgs_folder_path = train_imgs_path
self.labels_folder_path = train_labels_path
self.img_nums = self.train_nums
else:
self.imgs_folder_path = validate_imgs_path
self.labels_folder_path = validate_labels_path
self.img_nums = self.validate_nums
#load dataset
with open(self.imgs_folder_path, 'rb') as _imgs:
self._train_images = _imgs.read()
with open(self.labels_folder_path, 'rb') as _labs:
self._train_labels = _labs.read()
def __getitem__(self, index):
image = self.getImages(self._train_images, index)
label = self.getLabels(self._train_labels, index)
return image,label
def __len__(self):
return self.img_nums
def getImages(self, image, index):
img_size_bit = struct.calcsize('>784B')
start_index = struct.calcsize('>IIII') + index * img_size_bit
temp = struct.unpack_from('>784B', image, start_index)
img = np.reshape(temp, (28, 28))
# img = self.normalization(np.array(temp, dtype=float))
# print(img)
return img
def getLabels(self, label, index):
lab_size_bit = struct.calcsize('>1B')
start_index = struct.calcsize('>II') + index * lab_size_bit
lab = struct.unpack_from('>1B', label, start_index)
return lab
def normalization(self, x):
max = float(255)
min = float(0)
for i in range(0, 784):
x[i] = (x[i] - min) / (max - min)
return x
usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type
对我这个类进行测试,测试程序如下所示(输出结果附后):
dataiter = iter(usermnist_train_loader)
images,labels = dataiter.next()
print(images.shape, labels)
print(type(images), type(labels))
plt.imshow(images[0])
plt.show()
img, lab = usermnist_train.__getitem__(6) # get the 34th sample
print(type(img))
print(type(lab))
plt.imshow(img)
plt.show()
其实关于这个类的重载最重要的是其使用方法,也就是这2个语句:
usermnist_train = UserMNIST(train=True)#how to one data,return numpy(user define) type
usermnist_train_loader = DataLoader(dataset=usermnist_train, batch_size=user_batch_size, shuffle=False)#do somethings to get all data,return tensor type
UserMNIST是继承了Dataset类的实现,从测试样例之中就可以看出来正确继承了,其实例化之后的usermnist_train做为参数通过DataLoader进行数据加载。
我的理解是DataLoader类是一种数据集加载优化的方法,因为如果用户有大量的数据需要加载的话会很花时间,DataLoader类可以通过多线程加载进行提速,而且可以打乱数据集的顺序增加随机性,使得训练集环境更接近实际环境等优点。
以上是关于[基于Pytorch的MNIST识别02]用户数据集的读取的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)
图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)