python读取MNIST image数据

Posted jude_python

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 Python 读取 MNIST image 数据的相关知识，希望对你有一定的参考价值。

MNIST 数据集可从 Yann LeCun 的网站下载。第一种方式直接读取解压后的 IDX 文件：

import numpy as np
import struct

def loadImageSet(which=0, dataset_dir="../dataset"):
    """Load an MNIST image set from an uncompressed IDX3 file.

    Args:
        which: 0 loads ``train-images-idx3-ubyte``; any other value loads
            ``t10k-images-idx3-ubyte``.
        dataset_dir: directory holding the IDX files. The default is the
            ``../dataset`` directory that the original code hard-coded.

    Returns:
        numpy array of shape (num_images, rows, cols) with numpy's default
        integer dtype (the same dtype the original tuple-based code produced).

    Raises:
        IOError/OSError: if the file cannot be opened.
        struct.error: if the file is shorter than the 16-byte header.
        ValueError: if the magic number is not 2051 (IDX3, unsigned byte) or
            the file holds fewer pixels than the header declares.
    """
    import os

    print("load image set")
    name = "train-images-idx3-ubyte" if which == 0 else "t10k-images-idx3-ubyte"
    # `with` guarantees the file is closed even if reading fails.
    with open(os.path.join(dataset_dir, name), "rb") as binfile:
        buffers = binfile.read()

    head = struct.unpack_from(">IIII", buffers, 0)
    print("head, %s" % (head,))
    # IDX header: magic, item count, then the two dimension sizes (rows, cols).
    magic, img_num, rows, cols = head
    if magic != 2051:
        raise ValueError("%s is not an IDX3 image file (magic %d)" % (name, magic))

    offset = struct.calcsize(">IIII")
    # One C-level pass over the pixel bytes instead of unpacking tens of
    # millions of Python ints through struct.
    pixels = np.frombuffer(buffers, dtype=np.uint8, count=img_num * rows * cols,
                           offset=offset)
    # astype(int) keeps the original's default-int dtype and also yields a
    # writable copy (frombuffer over bytes is read-only).
    imgs = pixels.astype(int).reshape(img_num, rows, cols)
    print("load imgs finished")
    return imgs

def loadLabelSet(which=0, dataset_dir="../dataset"):
    """Load an MNIST label set from an uncompressed IDX1 file.

    Args:
        which: 0 loads ``train-labels-idx1-ubyte``; any other value loads
            ``t10k-labels-idx1-ubyte``.
        dataset_dir: directory holding the IDX files. The default is the
            ``../dataset`` directory that the original code hard-coded.

    Returns:
        numpy array of shape (num_labels, 1) with numpy's default integer
        dtype (the same dtype the original tuple-based code produced).

    Raises:
        IOError/OSError: if the file cannot be opened.
        struct.error: if the file is shorter than the 8-byte header.
        ValueError: if the magic number is not 2049 (IDX1, unsigned byte) or
            the file holds fewer labels than the header declares.
    """
    import os

    print("load label set")
    name = "train-labels-idx1-ubyte" if which == 0 else "t10k-labels-idx1-ubyte"
    # `with` guarantees the file is closed even if reading fails.
    with open(os.path.join(dataset_dir, name), "rb") as binfile:
        buffers = binfile.read()

    head = struct.unpack_from(">II", buffers, 0)
    print("head, %s" % (head,))
    magic, label_num = head
    if magic != 2049:
        raise ValueError("%s is not an IDX1 label file (magic %d)" % (name, magic))

    offset = struct.calcsize(">II")
    raw = np.frombuffer(buffers, dtype=np.uint8, count=label_num, offset=offset)
    # astype(int) keeps the original's default-int dtype and returns a
    # writable copy; the column shape matches the original reshape.
    labels = raw.astype(int).reshape(label_num, 1)

    print("load label finished")
    return labels

if __name__ == "__main__":
    # Demo: read the training images and training labels (which=0).
    train_images = loadImageSet(which=0)
    # Optional visual spot-check through the project-local PlotUtil module:
    # import PlotUtil as pu
    # pu.showImgMatrix(train_images[0])
    loadLabelSet(which=0)

以及一个方便训练的 reader（读取 mnist.pkl.gz，并按 batch 返回数据）：

import numpy as np
import struct
import gzip
import cPickle

class MnistReader(object):
    """Mini-batch reader over a pickled MNIST archive (``mnist.pkl.gz``).

    The archive must unpickle to ``(train_set, valid_set, test_set)``, where
    each set is an ``(images, labels)`` pair.
    """

    def __init__(self, mnist_path, data_dim=1, one_hot=True):
        """
        mnist_path: the path of mnist.pkl.gz
        data_dim=1 [N,784]
        data_dim=3 [N,28,28,1]
        one_hot: one hot encoding(like: [0,1,0,0,0,0,0,0,0,0]) if true
        """
        self.mnist_path = mnist_path
        self.data_dim = data_dim
        self.one_hot = one_hot
        self.load_minist(mnist_path)

        # list(...) so that len() and in-place shuffling also work on
        # Python 3, where zip() returns a one-shot iterator.
        self.train_datalabel = list(zip(self.train_x, self.train_y))
        self.valid_datalabel = list(zip(self.valid_x, self.valid_y))

        # Index (in batches, not samples) of the next training batch in the
        # current epoch.
        self.batch_offset_train = 0

    def _make_batch(self, pairs):
        """Turn (image, label) pairs into (imgs, labels) lists per data_dim/one_hot."""
        imgs = []
        labels = []
        for d, l in pairs:
            if self.data_dim == 3:
                d = np.reshape(d, [28, 28, 1])
            imgs.append(d)
            if self.one_hot:
                a = np.zeros(10)
                a[l] = 1
                labels.append(a)  # the original appended `l`, ignoring `a`
            else:
                labels.append(l)
        return imgs, labels

    def next_batch_train(self, batch_size):
        """
        return list of images with shape [N,784] or [N,28,28,1] dependents on self.data_dim
               and list of labels with shape [N] or [N,10] dependents on self.one_hot

        Consecutive calls walk through disjoint batches. When fewer than a
        full batch remains, the training pairs are reshuffled and a new epoch
        starts. The trailing partial batch of each epoch is skipped.

        Raises ValueError if batch_size < 1 or larger than the training set.
        (The original recursed forever or divided by zero in these cases.)
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
        num_batches = len(self.train_datalabel) // batch_size
        if num_batches == 0:
            raise ValueError("batch_size %d exceeds training set size %d"
                             % (batch_size, len(self.train_datalabel)))
        if self.batch_offset_train >= num_batches:
            # Epoch finished: restart from the first batch of a fresh shuffle.
            self.batch_offset_train = 0
            np.random.shuffle(self.train_datalabel)
        # The offset counts batches, so scale it to a sample index. Using it
        # directly made consecutive batches overlap by batch_size - 1 samples.
        start = self.batch_offset_train * batch_size
        self.batch_offset_train += 1
        return self._make_batch(self.train_datalabel[start:start + batch_size])

    def next_batch_val(self, batch_size):
        """
        return list of images with shape [N,784] or [N,28,28,1] dependents on self.data_dim
               and list of labels with shape [N] or [N,10] dependents on self.one_hot

        Each call reshuffles the validation pairs and returns the first
        batch_size of them. A short list is returned if batch_size exceeds
        the validation set size.
        """
        np.random.shuffle(self.valid_datalabel)
        # Draw from the validation set; the original read train_datalabel.
        return self._make_batch(self.valid_datalabel[:batch_size])

    def load_minist(self, dataset):
        """Unpickle (train, valid, test) splits from the gzip file `dataset`.

        SECURITY: unpickling can execute arbitrary code, so only load
        archives from a trusted source.
        """
        import pickle
        import sys

        print("load dataset")
        with gzip.open(dataset, "rb") as f:
            if sys.version_info[0] >= 3:
                # Python 2 pickles containing NumPy arrays need latin1
                # decoding under Python 3; it is harmless for Python 3 pickles.
                train_set, valid_set, test_set = pickle.load(f, encoding="latin1")
            else:
                train_set, valid_set, test_set = pickle.load(f)
        self.train_x, self.train_y = train_set
        self.valid_x, self.valid_y = valid_set
        self.test_x, self.test_y = test_set
        print("train image,label shape: %s %s" % (self.train_x.shape, self.train_y.shape))
        print("valid image,label shape: %s %s" % (self.valid_x.shape, self.valid_y.shape))
        print("test  image,label shape: %s %s" % (self.test_x.shape, self.test_y.shape))
        print("load dataset end")

if __name__ == "__main__":
    # Smoke test: load the pickled archive and print one training sample,
    # reshaped to 28x28x1 (data_dim=3), together with its label.
    mnist = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
    data, label = mnist.next_batch_train(batch_size=1)
    print(data)
    print(label)

第三种加载方式直接读取 .gz 压缩的 IDX 文件，需要 gzip 和 struct 模块（以及 numpy）：

import gzip, struct

def _read(image, label, minist_dir='your_dir/'):
    """Read one gzip-compressed MNIST image/label IDX file pair.

    Args:
        image: name of the gzipped IDX3 image file, relative to minist_dir.
        label: name of the gzipped IDX1 label file, relative to minist_dir.
        minist_dir: directory holding both files. The default is the
            placeholder the original hard-coded; pass your real directory.

    Returns:
        (images, labels): images is a writable uint8 array of shape
        (num, rows, cols); labels is a writable 1-D int8 array of length num.

    Raises:
        ValueError: on an unexpected magic number, or if the image header
            and the label file disagree on the number of items.
    """
    import os

    with gzip.open(os.path.join(minist_dir, label), 'rb') as flbl:
        magic, num_labels = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError("%s is not an IDX1 label file (magic %d)" % (label, magic))
        # np.frombuffer(...).copy() replaces the deprecated binary mode of
        # np.fromstring and still returns a writable array. int8 is kept from
        # the original; digit labels 0-9 fit in it.
        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(os.path.join(minist_dir, image), 'rb') as fimg:
        magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError("%s is not an IDX3 image file (magic %d)" % (image, magic))
        if num_images != len(labels):
            raise ValueError("%s has %d images but %s has %d labels"
                             % (image, num_images, label, len(labels)))
        images = np.frombuffer(fimg.read(), dtype=np.uint8).copy().reshape(
            num_images, rows, cols)
    return images, labels

def get_data():
    """Load the gzipped MNIST training and test splits via ``_read``.

    Returns:
        [train_img, train_label, test_img, test_label], read from the
        ``train-*`` and ``t10k-*`` IDX .gz files respectively.
    """
    train_img, train_label = _read(
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz')
    test_img, test_label = _read(
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz')
    return [train_img, train_label, test_img, test_label]

以上是关于 Python 读取 MNIST image 数据的主要内容。如果未能解决你的问题，请参考以下文章：

MNIST手写数字数据集

Python读取MNIST数据集

通过MATLAB读取mnist数据库

mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同

python读取mnist label数据库

手写数字识别——基于全连接层和MNIST数据集