CIFAR-10和python读取

Posted 精诚所至 金石为开

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CIFAR-10和python读取相关的知识,希望对你有一定的参考价值。

1、CIFAR-10,是一个用于做图像分类研究的数据集。

  • 由60000个图片组成
  • 6万个图片中,5万张用于训练,1万张用于测试
  • 每个图片是32x32像素
  • 所有图片可以分成10类
  • 每个图片都有一个标签,标记属于哪一个类
  • 测试集中一个类对应1000张图
  • 训练集中将5万张图分为5份
  • 类之间的图片是互斥的,不存在类别重叠的情况

 

下图展示了具体的分类, 

 

2、 数据集加载:

CIFAR-10提供了三个版本的数据格式:python,matlab,二进制 。

这里以python的加载为例,参考http://cs231n.github.io/assignments2018/assignment1/

 

from __future__ import print_function

from six.moves import cPickle as pickle
import numpy as np
import os
from scipy.misc import imread
import platform

#读取文件
def load_pickle(f):
    version = platform.python_version_tuple() # 取python版本号
    if version[0] == \'2\':
        return  pickle.load(f) # pickle.load, 反序列化为python的数据类型
    elif version[0] == \'3\':
        return  pickle.load(f, encoding=\'latin1\')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
  """ load single batch of cifar """
  with open(filename, \'rb\') as f:
    datadict = load_pickle(f)   # dict类型
    X = datadict[\'data\']        # X, ndarray, 像素值
    Y = datadict[\'labels\']      # Y, list, 标签, 分类
    
    # reshape, 一维数组转为矩阵10000行3列。每个entries是32x32
    # transpose,转置
    # astype,复制,同时指定类型
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    Y = np.array(Y)
    return X, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = [] # list
  ys = []
  
  # 训练集batch 1~5
  for b in range(1,6):
    f = os.path.join(ROOT, \'data_batch_%d\' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X) # 在list尾部添加对象X, x = [..., [X]]
    ys.append(Y)    
  Xtr = np.concatenate(xs) # [ndarray, ndarray] 合并为一个ndarray
  Ytr = np.concatenate(ys)
  del X, Y

  # 测试集
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, \'test_batch\'))
  return Xtr, Ytr, Xte, Yte

batch数据反序列化出来是

{

  \'data\': 像素数据,

    \'labels\':分类标签

}

 

其中涉及到的python基础:

 1、from __future__ import print_function, __future__是用于在老版本python中使用新版本特性

 2、from six.moves import cPickle as pickle, 是序列化和反序列化库,pickle.load,反序列化为python的数据类型

 3、list的append方法,在list尾部添加对象,不需要和之前的数据类型一致

 4、numpy的concatenate,合并array

 

Reference:

 http://www.cs.toronto.edu/~kriz/cifar.html

 http://cs231n.github.io/assignments2018/assignment1/

以上是关于CIFAR-10和python读取的主要内容,如果未能解决你的问题,请参考以下文章

CIFAR-10 图像识别

Tensorflow机器学习入门——cifar10数据集的读取展示与保存

ResNet18迁移学习CIFAR10分类任务(附python代码)

利用Tensorflow读取二进制CIFAR-10数据集

Python读入CIFAR-10数据库

TFRecord文件的读写