Deeplab v3 : 源码训练和测试
Posted 明天去哪
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Deeplab v3 : 源码训练和测试相关的知识,希望对你有一定的参考价值。
本文主要介绍根据github tensorflow/models中官方代码来训练deeplab v3+
源代码: https://github.com/tensorflow/models/tree/master/research/deeplab
配置deeplab v3
- Clone 源代码, https://github.com/tensorflow/models.git
- 根据官方文档进行安装,https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/installation.md
这里有几个需要注意的地方:
(1) cuda 9.0 & tensorflow 1.6以上版本
(2) 需要将models/research/slim路径导入到PYTHONPATH环境变量中,这个是因为deeplab中的一些工具比如multigrid使用的是slim中实现的
export PYTHONPATH=$PYTHONPATH:/path-to/models/research/slim
(3) 然后进行测试就可以得到一个结果,测试需要在models/research/下进行,这样很不方便,想要在models/research/deeplab下进行,可以在models/research/deeplab/model_test.py中导入deeplab模块就可以了,一个例子如下:
import sys
sys.path.append('/path-to/models/research')
测试成功,则表示基本配置成功,可以进行训练的配置了.
训练
参考官方文档: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/cityscapes.md
以Cityscapes为例进行训练
- 将Cityscapes数据转为tfrecord,使用models/research/deeplab/datasets下的脚本: convert_cityscapes.sh、build_cityscapes_data.py、build_data.py脚本,这个需要改一下convert_cityscapes.sh中的一些路径,基本没有什么坑。这三个文件比较简单,可以读一下,之后做数据集基本就基于这三个文件了.
- 使用脚本进行训练
CHECKPOINT_PATH='/path-to/models/research/deeplab/initial-checkpoint/xception/model.ckpt'
TRAIN_DIR_PATH='/path-tomodels/research/deeplab/train_dir'
CITYSCAPES_PATH='/path-to/cityscapes/tfrecord'
python train.py \\
--logtostderr \\
--training_number_of_steps=90000 \\
--train_split="train" \\
--model_variant="xception_65" \\
--atrous_rates=6 \\
--atrous_rates=12 \\
--atrous_rates=18 \\
--output_stride=16 \\
--decoder_output_stride=4 \\
--train_crop_size=769 \\
--train_crop_size=769 \\
--train_batch_size=1 \\
--dataset="cityscapes" \\
--tf_initial_checkpoint="$CHECKPOINT_PATH" \\
--train_logdir="$TRAIN_DIR_PATH" \\
--dataset_dir="$CITYSCAPES_PATH"
- 评估/测试
CHECKPOINT_PATH='/path-to/models/research/deeplab/initial-checkpoint/deeplabv3_cityscapes_train'
EVAL_DIR_PATH='/path-to/models/research/deeplab/eval_dir'
CITYSCAPES_PATH='/path-to/cityscapes/tfrecord'
python eval.py \\
--logtostderr \\
--eval_split="val" \\
--model_variant="xception_65" \\
--atrous_rates=6 \\
--atrous_rates=12 \\
--atrous_rates=18 \\
--output_stride=16 \\
--decoder_output_stride=4 \\
--eval_crop_size=1025 \\
--eval_crop_size=2049 \\
--dataset="cityscapes" \\
--checkpoint_dir="$CHECKPOINT_PATH" \\
--eval_logdir="$EVAL_DIR_PATH" \\
--dataset_dir="$CITYSCAPES_PATH"
注意,如果使用官方提供的checkpoint,压缩包中是没有checkpoint文件的,需要手动添加一个checkpoint文件
4. 性能
根据官方提供的checkpoint
(1) official-deeplabv3+, tensorflow。eval OS: 16, scale: [1.0]
miou: 0.787332237
(2) official-deeplabv3+, tensorflow。eval OS: 16, scale: [0.75:0.25:1.75]
miou: 0.806650937
注意
由于是第一次跑tf,免不了有很多的坑
1. tf默认占用所有GPU的所有计算资源,通常可能如果只想使用其中一个,或者即可,可以在需要执行的脚本前面加上: CUDA_VISIBLE_DEVICES=gpu_id,即可
2. 如果使用官方提供的checkpoint,压缩包中是没有checkpoint文件的,需要手动添加一个checkpoint文件
以上是关于Deeplab v3 : 源码训练和测试的主要内容,如果未能解决你的问题,请参考以下文章
MATLAB深度学习采用 Deeplab v3+ 实现全景分割
在对 Cityscapes 语义分割数据集进行 deeplab v3+ 训练时遇到错误