21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型
Posted helloworld0604
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型相关的知识,希望对你有一定的参考价值。
书籍源码:https://github.com/hzy46/Deep-Learning-21-Examples
CNN的发展已经很多了,ImageNet引发的一系列方法,LeNet,GoogLeNet,VGGNet,ResNet每个方法都有很多版本的衍生,tensorflow中带有封装好各方法和网络的函数,只要喂食自己的训练集就可以完成自己的模型,感觉超方便!!!激动!!!因为虽然原理流程了解了,但是要写出来真的。。。。好难,臣妾做不到啊~~~~~~~~
START~~~~
1.数据准备
首先了解下微调的概念: 以VGG为例
他的结构是卷积+全连接,卷积层分为5个部分共13层,conv1~conv5。还有三层全连接,即fc6,fc7,fc8。总共16层,因此被称为VGG16。
a.如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8,因为fc8原本的输出是1000类的概率。需要改为符合自身训练集的输出类别数。
b.训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。因为已经训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,这样做不仅节约大量训练时间,而且有助于分类器性能的提高。
载入VGG16的参数后,即可开始训练。此时需要指定训练层数的范围。一般而言,可以选择以下几种:
- 只训练fc8:训练范围一定要包含fc8这一层。这样的选择一般性能都不会太好,但速度很快;因为他只训练fc8,保持其他层的参数不动,相当于把VGG16当成一个“特征提取器”,用fc7层提取的特征做一个softmax的模型分类。
- 训练所有参数:耗时较慢,但能取得较高性能。
- 训练部分参数:通常是固定浅层参数不变,训练深层参数。如固定conv1、conv2部分的参数不训练,只训练conv3、conv4、conv5、fc6、fc7、fc8的参数。
这种训练方法就是对神经网络做微调。
1.1 切分train&test
书中提供了卫星图像数据集,有6个类别,分别是森林(wood),水域(water),岩石(rock),农田(wetland),冰川(glacier),城市区域(urban)
保存结构为data_prepare/pic,再下层有两个文件夹train和validation,各文件夹下有6个文件夹,放的是该类别下的图片。
1.2 转换成tfrecord格式
python data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite
参数解释:
-t pic/ :表示转换pic文件夹下的数据,该文件夹必须与上面的文件结构保持一致
--train-shards 2 :把训练集分成两块,即最后的训练数据就是两个tfrecord格式的文件。若数据集更大,可以分更多数据块
--validation-shards 2 :把验证集分成两块
--num-thread 2 :用两个线程来产生数据。注意线程数必须要能整除train-shards和validation-shards,来保证每个线程处理的数据块是相同的。
--dataset-name :给生成的数据集起个名字,即表示最后生成文件的开头是satellite_train和satellite_validation
运行上述命令后,就可以在 pic 文件夹中找到 5 个新生成的文件 ,分别是:
- 训练数据 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,
- 验证数据 satellite_validation_00000-of-00002.tfrecord、satellite_validation_00001-of-00002.tfrecord。
- label.txt 它表示图片的内部标签(数字)到真实类别(字符串)之间的映射顺序 。 如图片在 tfrecord 中的标签为 0 ,那么就对应 label.txt 第一行的类别,在 tfrecord的标签为1,就对应 label.txt 中第二行的类别,依此类推。
2.训练模型
2.1 TensorFlow Slim
Google 公司公布的一个图像分类工具包,它不仅定义了一些方便的接口,还提供了很多 ImageNet 数据集上常用的网络结构和预训练模型 。
截至2017年7月,Slim 提供包括 VGG16、VGG19、Inception V1 ~ V4、ResNet 50、ResNet 101、MobileNet 在内大多数常用模型的结构以及预训练模型,更多的模型还会被持续添加进来。
源码地址: https://github.com/tensorflow/models/tree/master/research/slim
可以通过执行 git clone https://github.corn/tensorflow/models.git 来获取
2.2 定义新的datasets文件<修改slim源码>
2.3 准备训练文件夹
2.4 开始训练
3.验证准确率
4.导出模型并对单张图片分类
THE END~~~~
以上是关于21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型的主要内容,如果未能解决你的问题,请参考以下文章
《21个项目玩转深度学习:基于TensorFlow的实践详解》高清带标签PDF版本学习下载
分享《21个项目玩转深度学习:基于TensorFlow的实践详解》PDF+源代码
分享《21个项目玩转深度学习:基于TensorFlow的实践详解》PDF+源代码
分享《21个项目玩转深度学习:基于TensorFlow的实践详解》+PDF+源码+何之源