python读取MNIST image数据
Posted jude_python
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python读取MNIST image数据相关的知识,希望对你有一定的参考价值。
import numpy as np
import struct
def loadImageSet(which=0, dataset_dir="../dataset"):
    """Load an uncompressed MNIST image file in IDX3 format.

    Args:
        which: 0 loads the training images (``train-images-idx3-ubyte``);
            any other value loads the test images (``t10k-images-idx3-ubyte``).
        dataset_dir: directory holding the IDX files. Defaults to
            ``../dataset``, the location the original code hard-coded.

    Returns:
        numpy integer array of shape ``(num_images, rows, cols)``.

    Raises:
        ValueError: if the header magic is not 2051 (unsigned-byte, 3-D IDX)
            or the file holds fewer pixels than the header declares.
    """
    import os

    print("load image set")
    name = "train-images-idx3-ubyte" if which == 0 else "t10k-images-idx3-ubyte"
    # `with` guarantees the file is closed even if reading fails.
    with open(os.path.join(dataset_dir, name), "rb") as binfile:
        buffers = binfile.read()

    header_fmt = ">IIII"  # big-endian: magic, count, rows, cols
    head = struct.unpack_from(header_fmt, buffers, 0)
    print("head, %s" % (head,))
    magic, imgNum, rows, cols = head
    if magic != 2051:
        raise ValueError("not an IDX3 unsigned-byte image file (magic=%d)" % magic)
    offset = struct.calcsize(header_fmt)
    count = imgNum * rows * cols
    if len(buffers) - offset < count:
        raise ValueError("image file truncated: header declares %d pixels" % count)

    # np.frombuffer reads the pixel bytes directly instead of first building a
    # tuple of tens of millions of Python ints via struct; astype(int) keeps the
    # integer dtype callers previously received.
    imgs = np.frombuffer(buffers, dtype=np.uint8, count=count, offset=offset).astype(int)
    imgs = imgs.reshape(imgNum, rows, cols)
    print("load imgs finished")
    return imgs
def loadLabelSet(which=0, dataset_dir="../dataset"):
    """Load an uncompressed MNIST label file in IDX1 format.

    Args:
        which: 0 loads the training labels (``train-labels-idx1-ubyte``);
            any other value loads the test labels (``t10k-labels-idx1-ubyte``).
        dataset_dir: directory holding the IDX files. Defaults to
            ``../dataset``, the location the original code hard-coded.

    Returns:
        numpy integer array of shape ``(num_labels, 1)``.

    Raises:
        ValueError: if the header magic is not 2049 (unsigned-byte, 1-D IDX)
            or the file holds fewer labels than the header declares.
    """
    import os

    print("load label set")
    name = "train-labels-idx1-ubyte" if which == 0 else "t10k-labels-idx1-ubyte"
    # `with` guarantees the file is closed even if reading fails.
    with open(os.path.join(dataset_dir, name), "rb") as binfile:
        buffers = binfile.read()

    header_fmt = ">II"  # big-endian: magic, count
    head = struct.unpack_from(header_fmt, buffers, 0)
    print("head, %s" % (head,))
    magic, imgNum = head
    if magic != 2049:
        raise ValueError("not an IDX1 unsigned-byte label file (magic=%d)" % magic)
    offset = struct.calcsize(header_fmt)
    if len(buffers) - offset < imgNum:
        raise ValueError("label file truncated: header declares %d labels" % imgNum)

    # Parse the label bytes directly; astype(int) keeps the integer dtype that
    # the previous struct-based implementation returned.
    labels = np.frombuffer(buffers, dtype=np.uint8, count=imgNum, offset=offset).astype(int)
    labels = labels.reshape(imgNum, 1)
    print("load label finished")
    return labels
if __name__ == "__main__":
    # Smoke run: load the training images and labels from the default location.
    train_images = loadImageSet()
    # Optional visualisation through a project-local plotting helper:
    # import PlotUtil as pu
    # pu.showImgMatrix(train_images[0])
    loadLabelSet()
以及一个方便训练时按批次取数据的 reader（读取 mnist.pkl.gz）：
import numpy as np
import struct
import gzip
import cPickle
class MnistReader(object):
    """Mini-batch reader over the pickled MNIST dataset (``mnist.pkl.gz``)."""

    def __init__(self, mnist_path, data_dim=1, one_hot=True):
        """
        mnist_path: path of mnist.pkl.gz, a gzip-compressed pickle holding
            (train_set, valid_set, test_set), each an (images, labels) pair.
        data_dim: 1 -> each image is returned as loaded (flat, batch [N, 784]);
            3 -> each image is reshaped to [28, 28, 1] (batch [N, 28, 28, 1]).
        one_hot: if True, labels are one-hot vectors of length 10
            (like [0,1,0,0,0,0,0,0,0,0]); otherwise the raw class ids.
        """
        self.mnist_path = mnist_path
        self.data_dim = data_dim
        self.one_hot = one_hot
        self.load_minist(mnist_path)
        # list() so the pairs can be sliced and shuffled in place; on Python 3
        # zip() returns a one-shot iterator.
        self.train_datalabel = list(zip(self.train_x, self.train_y))
        self.valid_datalabel = list(zip(self.valid_x, self.valid_y))
        # Index of the next training batch within the current epoch.
        self.batch_offset_train = 0

    def _make_batch(self, pairs):
        """Turn (image, label) pairs into (images, labels) lists per data_dim/one_hot."""
        imgs = []
        labels = []
        for d, l in pairs:
            if self.data_dim == 3:
                d = np.reshape(d, [28, 28, 1])
            imgs.append(d)
            if self.one_hot:
                a = np.zeros(10)
                a[l] = 1
                labels.append(a)
            else:
                labels.append(l)
        return imgs, labels

    def next_batch_train(self, batch_size):
        """Return the next non-overlapping training batch.

        Returns a list of images with shape [N, 784] or [N, 28, 28, 1]
        depending on self.data_dim, and a list of labels with shape [N] or
        [N, 10] depending on self.one_hot. When the current epoch has no full
        batch left, the training pairs are reshuffled and a new epoch starts.

        Raises:
            ValueError: if batch_size is not in 1..len(training set).
        """
        total = len(self.train_datalabel)
        if batch_size <= 0 or batch_size > total:
            raise ValueError("batch_size must be in 1..%d, got %r" % (total, batch_size))
        if self.batch_offset_train >= total // batch_size:
            # Epoch exhausted: reshuffle and restart from the first batch.
            self.batch_offset_train = 0
            np.random.shuffle(self.train_datalabel)
        # batch_offset_train counts batches, so scale it to an element index.
        start = self.batch_offset_train * batch_size
        batch = self.train_datalabel[start:start + batch_size]
        self.batch_offset_train += 1
        return self._make_batch(batch)

    def next_batch_val(self, batch_size):
        """Return a random batch of (at most batch_size) validation samples.

        Returns a list of images with shape [N, 784] or [N, 28, 28, 1]
        depending on self.data_dim, and a list of labels with shape [N] or
        [N, 10] depending on self.one_hot.
        """
        np.random.shuffle(self.valid_datalabel)
        return self._make_batch(self.valid_datalabel[:batch_size])

    def load_minist(self, dataset):
        """Load train/valid/test splits from the gzip-compressed pickle *dataset*."""
        try:
            import cPickle as pickle  # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle
            # The public mnist.pkl.gz was written by Python 2; latin1 decodes
            # its byte strings (numpy buffers) correctly on Python 3.
            load_kwargs = {"encoding": "latin1"}
        print("load dataset")
        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load dataset files from trusted sources.
        with gzip.open(dataset, 'rb') as f:
            train_set, valid_set, test_set = pickle.load(f, **load_kwargs)
        self.train_x, self.train_y = train_set
        self.valid_x, self.valid_y = valid_set
        self.test_x, self.test_y = test_set
        print("train image,label shape: %s %s" % (self.train_x.shape, self.train_y.shape))
        print("valid image,label shape: %s %s" % (self.valid_x.shape, self.valid_y.shape))
        print("test image,label shape: %s %s" % (self.test_x.shape, self.test_y.shape))
        print("load dataset end")
if __name__ == "__main__":
    # Smoke run: read ../dataset/mnist.pkl.gz and print one training batch of
    # size 1 with images reshaped to [28, 28, 1].
    mnist = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
    data, label = mnist.next_batch_train(batch_size=1)
    print(data)
    print(label)
第三种加载方式：直接读取 gzip 压缩的原始 IDX 文件（*.gz），需要 gzip 和 struct 模块：
import gzip, struct
def _read(image, label, minist_dir='your_dir/'):
    """Read one MNIST split from gzip-compressed IDX files.

    Args:
        image: file name of the gzipped IDX3 image file inside *minist_dir*.
        label: file name of the gzipped IDX1 label file inside *minist_dir*.
        minist_dir: directory holding both files. Defaults to the original
            ``'your_dir/'`` placeholder, so existing two-argument calls are
            unaffected.

    Returns:
        ``(image, label)``: a writable uint8 array of shape
        ``(num, rows, cols)`` and a writable int8 array of shape ``(num,)``.

    Raises:
        ValueError: on an unexpected magic number, a truncated file, or when
            the image and label files declare different sample counts.
    """
    import os

    with gzip.open(os.path.join(minist_dir, label), 'rb') as flbl:
        magic, num_labels = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError("not an IDX1 label file (magic=%d)" % magic)
        label_bytes = flbl.read()
    with gzip.open(os.path.join(minist_dir, image), 'rb') as fimg:
        magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError("not an IDX3 image file (magic=%d)" % magic)
        image_bytes = fimg.read()
    if num_images != num_labels:
        raise ValueError("image count %d != label count %d" % (num_images, num_labels))

    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # .copy() keeps the arrays writable, as np.fromstring's results were.
    # int8 is kept for labels for backward compatibility (values are 0-9).
    labels = np.frombuffer(label_bytes, dtype=np.int8, count=num_labels).copy()
    images = np.frombuffer(image_bytes, dtype=np.uint8,
                           count=num_images * rows * cols).copy()
    images = images.reshape(num_images, rows, cols)
    return images, labels
def get_data():
    """Load the MNIST train and test splits from the gzipped IDX files.

    Returns:
        list ``[train_img, train_label, test_img, test_label]``, each element
        as returned by ``_read``.
    """
    file_pairs = (
        ('train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz'),
        ('t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'),
    )
    result = []
    for image_file, label_file in file_pairs:
        images, labels = _read(image_file, label_file)
        result += [images, labels]
    return result
以上是关于python读取MNIST image数据的主要内容,如果未能解决你的问题,请参考以下文章