实现pytorch版efficientdet的全过程

Posted cx-99

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了实现pytorch版efficientdet的全过程相关的知识,希望对你有一定的参考价值。

一、安装环境

# install requirements
pip install pycocotools numpy opencv-python tqdm tensorboard tensorboardX pyyaml
pip install torch==1.4.0
pip install torchvision==0.5.0

二、下载pytorch版efficientdet源码

git clone https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch.git

源码链接:https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch

三、准备数据集

# your dataset structure should be like this
datasets/
    -your_project_name/
        -train_set_name/
            -*.jpg
        -val_set_name/
            -*.jpg
        -annotations
            -instances_{train_set_name}.json
            -instances_{val_set_name}.json

# for example, coco2017
datasets/
    -coco2017/
        -train2017/
            -000000000001.jpg
            -000000000002.jpg
            -000000000003.jpg
        -val2017/
            -000000000004.jpg
            -000000000005.jpg
            -000000000006.jpg
        -annotations
            -instances_train2017.json
            -instances_val2017.json

四、修改配置文件

# create a yml file {your_project_name}.yml under projectsfolder 
# modify it following coco.yml
 
# for example
project_name: coco
train_set: train2017
val_set: val2017
num_gpus: 4  # 0 means using cpu, 1-N means using gpus 

# mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

# this is coco anchors, change it if necessary
anchors_scales: [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
anchors_ratios: [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]

# objects from all labels from your dataset with the order from your annotations.
# its index must match your datasets category_id.
# category_id is one_indexed,
# for example, index of car here is 2, while category_id of is 3
obj_list: [person, bicycle, car, ...]

五、训练coco数据集

# train efficientdet-d1 on a custom dataset 
# with batchsize 8 and learning rate 1e-5

python train.py -c 1 -p your_project_name --batch_size 8 --lr 1e-5

六、训练带已训练好的权重的数据集

# train efficientdet-d2 on a custom dataset with pretrained weights
# with batchsize 8 and learning rate 1e-5 for 10 epoches

python train.py -c 2 -p your_project_name --batch_size 8 --lr 1e-5 --num_epochs 10  --load_weights /path/to/your/weights/efficientdet-d2.pth

# with a coco-pretrained, you can even freeze the backbone and train heads only
# to speed up training and help convergence.

python train.py -c 2 -p your_project_name --batch_size 8 --lr 1e-5 --num_epochs 10  --load_weights /path/to/your/weights/efficientdet-d2.pth  --head_only True

权重下载链接:https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/

七、尽早停止训练

# while training, press Ctrl+c, the program will catch KeyboardInterrupt
# and stop training, save current checkpoint.

八、恢复训练

# let says you started a training session like this.

python train.py -c 2 -p your_project_name --batch_size 8 --lr 1e-5  --load_weights /path/to/your/weights/efficientdet-d2.pth  --head_only True
 
# then you stopped it with a Ctrl+c, it exited with a checkpoint

# now you want to resume training from the last checkpoint
# simply set load_weights to last

python train.py -c 2 -p your_project_name --batch_size 8 --lr 1e-5  --load_weights last  --head_only True

九、评估模型性能

# eval on your_project, efficientdet-d5

python coco_eval.py -p your_project_name -c 5  -w /path/to/your/weights

十、调试训练(可选)

# when you get bad result, you need to debug the training result.
python train.py -c 2 -p your_project_name --batch_size 8 --lr 1e-5 --debug True

# then checkout test/ folder, there you can visualize the predicted boxes during training
# dont panic if you see countless of error boxes, it happens when the training is at early stage.
# But if you still cant see a normal box after several epoches, not even one in all image,
# then its possible that either the anchors config is inappropriate or the ground truth is corrupted.

十一、个人训练总结

最重要的是不放弃!!!

遇到错误,根据错误来源看代码,当然一般按照流程来不会出错。

以上是关于实现pytorch版efficientdet的全过程的主要内容,如果未能解决你的问题,请参考以下文章

《PyTorch 版 EfficientDet 比官方 TF 实现快 25 倍?》

[Pytorch系列-21]:Pytorch基础 - 反向链式求导的全过程拆解

Qt5.7 实现Https 认证全过程解析(亲自动手版)

一文搞懂 Flink 处理 Barrier 全过程

图卷积神经网络(GCN)综述与实现(PyTorch版)

基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络softmax)基础版