TensorFlow2 入门指南 | 07 数据集的加载预处理数据增强

Posted AI 菌

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TensorFlow2 入门指南 | 07 数据集的加载预处理数据增强相关的知识,希望对你有一定的参考价值。

前言:

本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地、高效地入门TensorFlow2 深度学习框架。如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大动力!

系列文章汇总:TensorFlow2 入门指南
Github项目地址:https://github.com/Keyird/TensorFlow2-for-beginner


一、数据集介绍与加载

在TensorFlow2中,内置了一些经典的数据集,可以通过简单的命令来加载这些数据集。这一小节,主要介绍一下在图像领域的一些经典数据集,以及加载这些数据集的方法。

(1)mnist

为了方便业界统一测试和评估算法, Lecun等人在1998年发布了手写数字图片数据集,命名为 MNIST,它包含了 0~9 共 10 种数字的手写图片,每种数字一共有 7000 张图片,采集自不同书写风格的真实手写图片,一共 70000 张图片。其中60000张图片作为训练集,用来训练模型。10000张图片作为测试集,用来训练或者预测。训练集和测试集共同组成了整个 MNIST 数据集。MNIST数据集中的每张图片,大小为28 × \\times × 28,同时只保留灰度信息(即单通道)。下图是MNIST数据集中的部分图片:

在这里插入图片描述
TensorFlow2 加载 mnist 数据集:

(x,y), (x_test, y_test) = keras.datasets.mnist.load_data()

(2)fashion-mnist

Fashion MNIST是一个定位在比 MNIST 识别问题更复杂的数据集,它的设定与 MNIST 几乎完全一样,包含了 10 类不同类型的衣服、鞋子、包等灰度图片,图片大小为28x28,共 70000 张图片,其中 60000 张用于训练集,10000 张用于测试集,如图下图所示,每行对应一种类别:

在这里插入图片描述
TensorFlow2 加载 fashion-mnist 数据集:

(x,y), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

(3)cifar-10/100

CIFAR10/100 数据集由加拿大 Institute For Advanced Research 机构发布,它包含了飞机、汽车、鸟、猫等共 10 大类,100个小类物体的彩色图片(10个大类中,每个类包含10个小类),共 6万 张图片。其中 5万 作为训练数据集,1万作为测试数据集。种类样片如下图所示:

在这里插入图片描述
CIFAR10/100官方网址:https://www.cs.toronto.edu/~kriz/cifar.html

TensorFlow2 加载 CIFAR10 数据集:

(x,y), (x_test, y_test) = keras.datasets.cifar10.load_data()

TensorFlow2 加载 CIFAR100 数据集:

(x,y), (x_test, y_test) = keras.datasets.cifar100.load_data()

二、数据集准备

在数据集输入网络之前,通常要将下载好的数据集以一定的格式打包成张量的形式,然后按照一定的batch送入网络中。其中,batch表示的是一次送入神经网络的图片张数。下面,我们以mnist数据集为例,在送入神经网络之前,进行一系列的操作。

(1)将数据集打包成张量

db = tf.data.Dataset.from_trnsor_slices((x, y))
db_test = tf.data.Dataset.from_trnsor_slices((x_test, y_test))

(2)数据预处理

通常,数据集中图片的像素值都是在0-255之间,需要将其归一化到0-1之间。对于分类任务,标签y一般需要将其处理成one-hot编码的形式。

def preprocess(x, y):
    """ 数据集预处理 """
    x = tf.cast(x, dtype=tf.float32)/255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y

(3)数据集准备

通过以下方式,对数据集进行预处理、打乱、设置每次送入网路的数据量batch,以及网络迭代训练的epoch数。

# 对数据集分别进行preprocess中的预处理、将数据集中60000张图片顺序打乱、batch设置为32,epoch设置为10
db = db.map(preprocess).shuffle(60000).batch(32).repeat(10)

三、数据增强

所谓数据增强,就是通过对原图进行旋转、缩放、平移、裁剪、改变视角、遮挡某局部区域等操作,而不改变图片的类别标签,从而达到扩充数据集的效果。数据增强是解决模型过拟合的重要方式之一,至于模型的过拟合,这在后面的文章中会将具体讲到。

这一小节,主要来学习TensorFlow2中的数据增强的操作。TensorFlow2 中提供了常用图片的处理函数,位于 tf.image 子模块中。我们一般将数据增强实现在预处理函数 preprocess 中,这样方便后续统一对图片进行数据增强操作。

(1)tf.image.resize

通过 tf.image.resize(x, [w, h]),可将输入的图片缩放成 wxh 大小的图片:

def preprocess(x,y):
    """ x: 图片的路径,y:图片的数字编码 """
    x = tf.io.read_file(x)  # 读图片
    x = tf.image.decode_jpeg(x, channels=3)  # RGBA
    x = tf.image.resize(x, [244, 244])  # 图片缩放到 244x244 
    y = tf.one_hot(y, depth=10)  # one_hot编码
    return x, y

然后,通过上一节的方式九可以对数据集统一进行resize操作:

db = tf.data.Dataset.from_trnsor_slices((x, y))
db = db.map(preprocess)

(2) tf.image.random_crop

通过在原图的左右或者上下方向去掉部分边缘像素(随机裁剪),可以保持图片主体不变,同时获得新的图片样本。在实际裁剪时,一般先将图片缩放到略大于网络输入尺寸的大小,再进行裁剪到合适大小,例如网络的输入大小为 224x224,那么我们先通过 resize 函数将图片缩放到 244x244 大小,再通过 tf.image.random_crop 随机裁剪到 224x224 大小。实现如下:

# 图片先缩放到稍大尺寸
y = tf.image.resize(x, [244, 244])
# 再随机裁剪到合适尺寸
y = tf.image.random_crop(x, [224, 224, 3])

(3)tf.image.random_flip_left_right

通过 tf.image.random_flip_left_right(x) 可实现图片在水平方向的随机翻转:

y = tf.image.random_flip_left_right(x)

(4)tf.image.random_flip_up_down

通过 tf.image.random_flip_up_down(x) 可实现图片在竖直方向的随机翻转:

y = tf.image.random_flip_up_down(x)

(5)image.rgb_to_hsv

通过 image.rgb_to_hsv(x) 可将原图的颜色空间由 RGB 转向 HSV:

y = image.rgb_to_hsv(x) 

除了HSV颜色空间外,还可以转变为YIQ、YUV格式,方法分别如下:

y = tf.image.rgb_to_yiq(x)
y = tf.image.rgb_to_yuv(x)

(6)tf.image.central_crop

通过 tf.image.central_crop(x, central_fraction) 实现对原图中心裁剪:

y = tf.image.central_crop(x, 0.5)

其中,参数 0.5 表示对图像中间一半的区域进行裁剪。


本教程所有代码会逐渐上传github仓库:https://github.com/Keyird/TensorFlow2-for-beginner
如果对你有帮助的话,欢迎star收藏~

最好的关系是互相成就,各位的「三连」就是【AI 菌】创作的最大动力,我们下期见!

在这里插入图片描述

以上是关于TensorFlow2 入门指南 | 07 数据集的加载预处理数据增强的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow2快速入门- MNIST 数据集详解

物体检测快速入门系列 | 01-基于Tensorflow2.x Object Detection API构建自定义物体检测器

不平衡图像数据集 (Tensorflow2)

Tensorflow1 和 Tensorflow2 中的批处理

Tensorflow2.0的手写数字识别系统(Mnist数据集)

ResNet实战:tensorflow2.X版本,ResNet50图像分类任务(大数据集)