YOLOv5源码逐行超详细注释与解读——训练部分train.py

Posted 路人贾'ω'

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了YOLOv5源码逐行超详细注释与解读——训练部分train.py相关的知识,希望对你有一定的参考价值。

前言

本篇文章主要是对YOLOv5项目的训练部分train.py。通常这个文件主要是用来读取用户自己的数据集,加载模型并训练。

文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑!

友情提示:全文近5万字,可以先点收藏☆再慢慢看哦~

源码下载地址:mirrors / ultralytics / yolov5 · GitCode


YOLOv5源码逐行讲解系列前期回顾:  

YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析
YOLOv5源码逐行超详细注释与解读(2)——推理部分detect.py


目录

前言

目录

🚀一、导包和基本配置

1.1 Usage

1.2 导入安装好的python库

1.3 获取当前文件的绝对路径

1.4 加载自定义模块

1.5 分布式训练初始化

🚀二、执行main()函数

2.1 检查分布式训练环境

2.2 判断是否断点续训

2.3 判断是否分布式训练

2.4  判断是否进化训练

🚀三、设置opt参数

🚀四、执行train()函数

4.1 加载参数和初始化配置信息

4.1.1 载入参数

4.1.2 创建训练权重目录和保存路径

4.1.3 读取超参数配置文件

4.1.4 设置参数的保存路径

4.1.5 加载日志信息

4.1.6 加载其它参数

4.2 加载网络模型

4.2.1 加载预训练模型

4.2.2 设置冻结层

 4.2.3 设置优化器

4.2.4 设置学习率

4.2.5 训练前最后准备

4.3 加载数据集

4.3.1 创建数据集

4.3.2 计算anchor

4.4 训练过程

4.4.1 初始化训练需要的模型参数

4.4.2 训练热身部分

4.4.3 开始训练

4.4.4 训练完成保存模型

4.5 打印信息并释放显存

🚀五、执行run()函数

🚀六、train.py代码全部注释


🚀一、导包和基本配置

1.1 Usage

"""
Train a YOLOv5 model on a custom dataset
在数据集上训练 yolo v5 模型
Usage:
    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
    训练数据为coco128 coco128数据集中有128张图片 80个类别,是规模较小的数据集
"""

这里是开头作者注释的一个部分,意在说明一些项目基本情况。
第一行表示我们用的模型是YOLOv5

第二行表示我们传入的data数据集是coco128数据集,有128张图片,80个类别,使用的权重模型是yolov5s模型,–img表示图片大小640。

1.2 导入安装好的python库

'''======================1.导入安装好的python库====================='''
import argparse  # 解析命令行参数模块
import math  # 数学公式模块
import os  # 与操作系统进行交互的模块 包含文件路径操作和解析
import random  # 生成随机数模块
import sys  # sys系统模块 包含了与Python解释器和它的环境有关的函数
import time   # 时间模块 更底层
from copy import deepcopy  # 深度拷贝模块
from datetime import datetime  # datetime模块能以更方便的格式显示日期或对日期进行运算。
from pathlib import Path  # Path将str转换为Path对象 使字符串路径易于操作的模块

import numpy as np  # numpy数组操作模块
import torch # 引入torch
import torch.distributed as dist  # 分布式训练模块
import torch.nn as nn  # 对torch.nn.functional的类的封装 有很多和torch.nn.functional相同的函数
import yaml  # yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件。
from torch.cuda import amp  # PyTorch amp自动混合精度训练模块
from torch.nn.parallel import DistributedDataParallel as DDP  # 多卡训练模块
from torch.optim import SGD, Adam, lr_scheduler   # tensorboard模块
from tqdm import tqdm  # 进度条模块

首先,导入一下常用的python库

  • argparse:  它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息
  • math:  调用这个库进行数学运算
  • os: 它提供了多种操作系统的接口。通过os模块提供的操作系统接口,我们可以对操作系统里文件、终端、进程等进行操作
  • random:  是使用随机数的Python标准库。random库主要用于生成随机数
  • sys: 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 Python 运行时环境的不同部分
  • time: Python中处理时间的标准库,是最基础的时间处理库
  • copy:  Python 中赋值语句不复制对象,而是在目标和对象之间创建绑定 (bindings) 关系。copy模块提供了通用的浅层复制和深层复制操作
  • datetime:  是Python常用的一个库,主要用于时间解析和计算
  • pathlib:  这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读

然后再导入一些 pytorch库

  • numpy:  科学计算库,提供了矩阵,线性代数,傅立叶变换等等的解决方案, 最常用的是它的N维数组对象
  • torch:   这是主要的Pytorch库。它提供了构建、训练和评估神经网络的工具
  • torch.distributed:  torch.distributed包提供Pytorch支持和通信基元,对多进程并行,在一个或多个机器上运行的若干个计算阶段
  • torch.nn:  torch下包含用于搭建神经网络的modules和可用于继承的类的一个子包
  • yaml:  yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件
  • torch.cuda.amp:  自动混合精度训练 —— 节省显存并加快推理速度
  • torch.nn.parallel:  构建分布式模型,并行加速程度更高,且支持多节点多gpu的硬件拓扑结构
  • torch.optim:  优化器 Optimizer。主要是在模型训练阶段对模型可学习参数进行更新,常用优化器有 SGD,RMSprop,Adam等
  • tqdm:  就是我们看到的训练时进度条显示

1.3 获取当前文件的绝对路径

'''===================2.获取当前文件的绝对路径========================'''
FILE = Path(__file__).resolve()  # __file__指的是当前文件(即train.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/train.py
ROOT = FILE.parents[0]  # YOLOv5 root directory  ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path:  # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径
    sys.path.append(str(ROOT))  # add ROOT to PATH  把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative ROOT设置为相对路径

这段代码会获取当前文件的绝对路径,并使用Path库将其转换为Path对象。

这一部分的主要作用有两个:

  • 将当前项目添加到系统路径上,以使得项目中的模块可以调用。
  • 将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件。

1.4 加载自定义模块

'''===================3..加载自定义模块============================'''
import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, NCOLS, check_dataset, check_file, check_git_status, check_img_size,
                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
                           one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first

这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:

  • val:  这个是测试集,我们下一篇再具体讲
  • models.experimental:  实验性质的代码,包括MixConv2d、跨层权重Sum等
  • models.yolo:  yolo的特定模块,包括BaseModel,DetectionModel,ClassificationModel,parse_model等
  • utils.autoanchor:  定义了自动生成锚框的方法
  • utils.autobatch:  定义了自动生成批量大小的方法
  • utils.callbacks:  定义了回调函数,主要为logger服务
  • utils.datasets:  dateset和dateloader定义代码
  • utils.downloads:  谷歌云盘内容下载
  • utils.general.py:   定义了一些常用的工具函数,比如检查文件是否存在、检查图像大小是否符合要求、打印命令行参数等等
  • utils.loggers :  日志打印
  • utils.loss:  存放各种损失函数
  • utils.metrics:   模型验证指标,包括ap,混淆矩阵等
  • utils.plots.py:    定义了Annotator类,可以在图像上绘制矩形框和标注信息
  • utils.torch_utils.py:   定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等

通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。


1.5 分布式训练初始化

'''================4.分布式训练初始化==========================='''
# https://pytorch.org/docs/stable/elastic/run.html该网址有详细介绍
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # -本地序号。这个 Worker 是这台机器上的第几个 Worker
RANK = int(os.getenv('RANK', -1))  # -进程序号。这个 Worker 是全局第几个 Worker
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))  # 总共有几个 Worker
'''
   查找名为LOCAL_RANK,RANK,WORLD_SIZE的环境变量,
   若存在则返回环境变量的值,若不存在则返回第二个参数(-1,默认None)
rank和local_rank的区别: 两者的区别在于前者用于进程间通讯,后者用于本地设备分配。
'''

接下来是设置分布式训练时所需的环境变量。分布式训练指的是多GPU训练,将训练参数分布在多个GPU上进行训练,有利于提升训练效率。


🚀二、执行main()函数

2.1 检查分布式训练环境

def main(opt, callbacks=Callbacks()):
    '''
    2.1  检查分布式训练环境
    '''
    # Checks
    if RANK in [-1, 0]:  # 若进程编号为-1或0
        # 输出所有训练参数 / 参数以彩色的方式表现
        print_args(FILE.stem, opt)
        # 检测YOLO v5的github仓库是否更新,若已更新,给出提示
        check_git_status()
        # 检查requirements.txt所需包是否都满足
        check_requirements(exclude=['thop'])

这段代码主要是检查分布式训练的环境

RANK为-1或0,会执行下面三行代码,打印参数并检查github仓库和依赖库。

  • 第一行代码,负责打印文件所用到的参数信息,这个参数包括命令行传入进去的参数以及默认参数
  • 第二行代码,检查yolov5的github仓库是否更新,如果更新的话,会有一个提示
  • 第三行代码,检查requirements中要求的安装包有没有正确安装成功,没有成功的话会给予一定的提示

2.2 判断是否断点续训

    '''
    2.2  判断是否断点续训
    '''
    # Resume
    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
        # isinstance()是否是已经知道的类型
        # 如果resume是True,则通过get_lastest_run()函数找到runs为文件夹中最近的权重文件last.pt
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        # 判断是否为文件,若不是文件抛出异常
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        # opt.yaml是训练时的命令行参数文件
        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
            # 超参数替换,将训练时的命令行参数加载进opt参数对象中
            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
        # opt.cfg设置为'' 对应着train函数里面的操作(加载权重时是否加载权重里的anchor)
        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
        # 打印从ckpt恢复断点训练信息
        LOGGER.info(f'Resuming training from ckpt')
    else:
        # 不使用断点续训,就从文件中读取相关参数
        # check_file (utils/general.py)的作用为查找/下载文件 并返回该文件的路径。
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \\
            check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
        # 如果模型文件和权重文件为空,弹出警告
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        # 如果要进行超参数进化,重建保存路径
        if opt.evolve:
            # 设置新的项目输出目录
            opt.project = str(ROOT / 'runs/evolve')
            # 将resume传递给exist_ok
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        # 根据opt.project生成目录,并赋值给opt.save_dir  如: runs/train/exp1
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

这段代码主要是关于断点训练的判断和准备

断点训练是当训练异常终止或想调节超参数时,系统会保留训练异常终止前的超参数与训练参数,当下次训练开始时,并不会从头开始,而是从上次中断的地方继续训练。 

  • 使用断点续训,就从last.pt中读取相关参数
  • 不使用断点续训,就从文件中读取相关参数

2.3 判断是否分布式训练

    '''
    2.3  判断是否分布式训练
    '''
    # DDP mode -->  支持多机多卡、分布式训练
    # 选择程序装载的位置
    device = select_device(opt.device, batch_size=opt.batch_size)
    # 当进程内的GPU编号不为-1时,才会进入DDP
    if LOCAL_RANK != -1:
        #  用于DDP训练的GPU数量不足
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        # WORLD_SIZE表示全局的进程数
        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
        # 不能使用图片采样策略
        assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
        # 不能使用超参数进化
        assert not opt.evolve, '--evolve argument is not compatible with DDP training'
        # 设置装载程序设备
        torch.cuda.set_device(LOCAL_RANK)
        # 保存装载程序的设备
        device = torch.device('cuda', LOCAL_RANK)
        # torch.distributed是用于多GPU训练的模块
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")

 这段代码主要是检查DDP训练的配置,并设置GPU。

DDP(Distributed Data Parallel)用于单机或多机的多GPU分布式训练,但目前DDP只能在Linux下使用。这部分它会选择你是使用cpu还是gpu,假如你采用的是分布式训练的话,它就会额外执行下面的一些操作,我们这里一般不会用到分布式,所以也就没有执行什么东西。


2.4  判断是否进化训练

    '''
    2.4  判断是否进化训练
    '''
    # Train 训练模式: 如果不进行超参数进化,则直接调用train()函数,开始训练
    if not opt.evolve:# 如果不使用超参数进化
        # 开始训练
        train(opt.hyp, opt, device, callbacks)
        if WORLD_SIZE > 1 and RANK == 0:
            # 如果全局进程数大于1并且RANK等于0
            # 日志输出 销毁进程组
            LOGGER.info('Destroying process group... ')
            # 训练完毕,销毁所有进程
            dist.destroy_process_group()

这段代码是不进行进化训练的情况,此时正常训练。

如果输入evolve会执行else下面这些代码,因为我们没有输入evolve并且不是分布式训练,因此会执行train函数。也就是说,当不使用超参数进化训练时,直接把命令行参数传入train函数,训练完成后销毁所有进程。

接下来我们再看看使用超参数进化训练的情况:


    # Evolve hyperparameters (optional) 遗传进化算法,边进化边训练
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        # 超参数列表(突变范围 - 最小值 - 最大值)
        meta = 'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mixup (probability)
                'mixup': (1, 0.0, 1.0),  # image mixup (probability)
                'copy_paste': (1, 0.0, 1.0)  # segment copy-paste (probability)
        # 加载默认超参数
        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            # 如果超参数文件中没有'anchors',则设为3
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
        # 使用进化算法时,仅在最后的epoch测试和保存
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
        if opt.bucket:
            os.system(f'gsutil cp gs://opt.bucket/evolve.csv save_dir')  # download evolve.csv if exists

            """
            遗传算法调参:遵循适者生存、优胜劣汰的法则,即寻优过程中保留有用的,去除无用的。
            遗传算法需要提前设置4个参数: 群体大小/进化代数/交叉概率/变异概率
            """
        

这段代码是使用超参数进化训练的前期准备

首先指定每个超参数的突变范围、最大值、最小值,再为超参数的结果保存做好准备。

# 选择超参数的遗传迭代次数 默认为迭代300次
        for _ in range(opt.evolve):  # generations to evolve
            # 如果evolve.csv文件存在
            if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
                # Select parent(s)
                # 选择超参进化方式,只用single和weighted两种
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                # 加载evolve.txt
                x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
                # 选取至多前五次进化的结果
                n = min(5, len(x))  # number of previous results to consider
                # fitness()为x前四项加权 [P, R, mAP@0.5, mAP@0.5:0.95]
                # np.argsort只能从小到大排序, 添加负号实现从大到小排序, 算是排序的一个代码技巧
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                # 根据(mp, mr, map50, map)的加权和来作为权重计算hyp权重
                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                # 根据不同进化方式获得base hyp
                if parent == 'single' or len(x) == 1:
                    # 根据权重的几率随机挑选适应度历史前5的其中一个
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    # 对hyp乘上对应的权重融合层一个hpy, 再取平均(除以权重和)
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate 突变(超参数进化)
                mp, s = 0.8, 0.2  # mutation probability, sigma:突变概率
                npr = np.random
                # 根据时间设置随机数种子
                npr.seed(int(time.time()))
                # 获取突变初始值, 也就是meta三个值的第一个数据
                # 三个数值分别对应着: 变异初始概率, 最低限值, 最大限值(mutation scale 0-1, lower_limit, upper_limit)
                g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
                ng = len(meta)
                # 确保至少其中有一个超参变异了
                v = np.ones(ng)
                # 设置突变
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                # 将突变添加到base hyp上
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate

            # Constrain to limits 限制hyp在规定范围内
            for k, v in meta.items():
                # 这里的hyp是超参数配置文件对象
                # 而这里的k和v是在元超参数中遍历出来的
                # hyp的v是一个数,而元超参数的v是一个元组
                hyp[k] = max(hyp[k], v[1])  # 先限定最小值,选择二者之间的大值 ,这一步是为了防止hyp中的值过小
                hyp[k] = min(hyp[k], v[2])  # 再限定最大值,选择二者之间的小值
                hyp[k] = round(hyp[k], 5)  # 四舍五入到小数点后五位
                # 最后的值应该是 hyp中的值与 meta的最大值之间的较小者

            # Train mutation 使用突变后的参超,测试其效果
            results = train(hyp.copy(), opt, device, callbacks)

            # Write mutation results
            # 将结果写入results,并将对应的hyp写到evolve.txt,evolve.txt中每一行为一次进化的结果
            # 每行前七个数字 (P, R, mAP, F1, test_losses(GIOU, obj, cls)) 之后为hyp
            # 保存hyp到yaml文件
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results 将结果可视化 / 输出保存信息
        plot_evolve(evolve_csv)
        LOGGER.info(f'Hyperparameter evolution finished\\n'
                    f"Results saved to colorstr('bold', save_dir)\\n"
                    f'Use best hyperparameters example: $ python train.py --hyp evolve_yaml')

这段代码是开始超参数进化训练

超参数进化的步骤如下:

  • 1.若存在evolve.csv文件,读取文件中的训练数据,选择超参进化方式,结果最优的训练数据突变超参数
  • 2.限制超参进化参数hyp在规定范围内
  • 3.使用突变后的超参数进行训练,测试其效果
  • 4.训练结束后,将训练结果可视化,输出保存信息保存至evolution.csv,用于下一次的超参数突变。

原理:根据生物进化,优胜劣汰,适者生存的原则,每次迭代都会保存更优秀的结果,直至迭代结束。最后的结果即为最优的超参数

注意:使用超参数进化时要经过至少300次迭代,每次迭代都会经过一次完整的训练。因此超参数进化及其耗时,大家需要根据自己需求慎用。


🚀三、设置opt参数

=============================================三、设置opt参数==================================================='''

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    # 预训练权重文件
    parser.add_argument('--weights', type=str, default=ROOT / 'pretrained/yolov5s.pt', help='initial weights path')
    # 训练模型
    parser.add_argument('--cfg', type=str, default=ROOT / 'models/yolov5s.yaml', help='model.yaml path')
    # 训练路径,包括训练集,验证集,测试集的路径,类别总数等
    parser.add_argument('--data', type=str, default=ROOT / 'data/fire_data.yaml', help='dataset.yaml path')
    # hpy超参数设置文件(lr/sgd/mixup)./data/hyps/下面有5个超参数设置文件,每个文件的超参数初始值有细微区别,用户可以根据自己的需求选择其中一个
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
    # epochs: 训练轮次, 默认轮次为300次
    parser.add_argument('--epochs', type=int, default=300)
    # batchsize: 训练批次, 默认bs=16
    parser.add_argument('--batch-size', type=int, default=4, help='total batch size for all GPUs, -1 for autobatch')
    # imagesize: 设置图片大小, 默认640*640
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    # rect: 是否采用矩形训练,默认为False
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    # resume: 是否接着上次的训练结果,继续训练
    # 矩形训练:将比例相近的图片放在一个batch(由于batch里面的图片shape是一样的)
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    # nosave: 不保存模型  默认False(保存)  在./runs/exp*/train/weights/保存两个模型 一个是最后一次的模型 一个是最好的模型
    # best.pt/ last.pt 不建议运行代码添加 --nosave
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    # noval: 最后进行测试, 设置了之后就是训练结束都测试一下, 不设置每轮都计算mAP, 建议不设置
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    # noautoanchor: 不自动调整anchor, 默认False, 自动调整anchor
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    # evolve: 参数进化, 遗传算法调参
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    # bucket: 谷歌优盘 / 一般用不到
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    # cache: 是否提前缓存图片到内存,以加快训练速度,默认False
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    # mage-weights: 使用图片采样策略,默认不使用
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    # device: 设备选择
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    # parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    # multi-scale 是否进行多尺度训练
    parser.add_argument('--multi-scale', default=True, help='vary img-size +/- 50%%')
    # single-cls: 数据集是否多类/默认True
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    # optimizer: 优化器选择 / 提供了三种优化器
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    # sync-bn: 是否使用跨卡同步BN,在DDP模式使用
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    # dataloader的最大worker数量 (使用多线程加载图片)
    parser.add_argument('--workers', type=int, default=0, help='max dataloader workers (per RANK in DDP mode)')
    # 训练结果的保存路径
    parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
    # 训练结果的文件名称
    parser.add_argument('--name', default='exp', help='save to project/name')
    # 项目位置是否存在 / 默认是都不存在
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    # 四元数据加载器: 允许在较低 --img 尺寸下进行更高 --img 尺寸训练的一些好处。
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    # cos-lr: 余弦学习率
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    # 标签平滑 / 默认不增强, 用户可以根据自己标签的实际情况设置这个参数,建议设置小一点 0.1 / 0.05
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    # 早停止耐心次数 / 100次不更新就停止训练
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    # --freeze冻结训练 可以设置 default = [0] 数据量大的情况下,建议不设置这个参数
    parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
    # --save-period 多少个epoch保存一下checkpoint
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
    # --local_rank 进程编号 / 多卡使用
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    # Weights & Biases arguments
    # 在线可视化工具,类似于tensorboard工具
    parser.add_argument('--entity', default=None, help='W&B: Entity')
    # upload_dataset: 是否上传dataset到wandb tabel(将数据集作为交互式 dsviz表 在浏览器中查看、查询、筛选和分析数据集) 默认False
    parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
    # bbox_interval: 设置界框图像记录间隔 Set bounding-box image logging interval for W&B 默认-1   opt.epochs // 10
    parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
    # 使用数据的版本
    parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

    # 作用就是当仅获取到基本设置时,如果运行命令中传入了之后才会获取到的其他配置,不会报错;而是将多出来的部分保存起来,留到后面使用
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

opt参数解析: 

  • cfg:   模型配置文件,网络结构
  • data:   数据集配置文件,数据集路径,类名等
  • hyp:   超参数文件
  • epochs:   训练总轮次
  • batch-size:   批次大小
  • img-size:   输入图片分辨率大小
  • rect:   是否采用矩形训练,默认False
  • resume:   接着打断训练上次的结果接着训练
  • nosave:   不保存模型,默认False
  • notest:   不进行test,默认False
  • noautoanchor:   不自动调整anchor,默认False
  • evolve:   是否进行超参数进化,默认False
  • bucket:   谷歌云盘bucket,一般不会用到
  • cache-images:   是否提前缓存图片到内存,以加快训练速度,默认False
  • weights:   加载的权重文件
  • name:   数据集名字,如果设置:results.txt to results_name.txt,默认无
  • device:   训练的设备,cpu;0(表示一个gpu设备cuda:0);0,1,2,3(多个gpu设备)
  • multi-scale:   是否进行多尺度训练,默认False
  • single-cls:    数据集是否只有一个类别,默认False
  • adam:   是否使用adam优化器
  • sync-bn:   是否使用跨卡同步BN,在DDP模式使用
  • local_rank:   gpu编号
  • logdir:   存放日志的目录 workers: dataloader的最大worker数量

(关于调参,推荐大家看@迪菲赫尔曼大佬的这篇文章:手把手带你调参YOLOv5 (v5.0-v7.0)(验证)


🚀四、执行train()函数

4.1 加载参数和初始化配置信息

4.1.1 载入参数

''' =====================1.载入参数和初始化配置信息==========================  '''
    '''
    1.1 载入参数
    '''
def train(hyp,  # 超参数 可以是超参数配置文件的路径或超参数字典 path/to/hyp.yaml or hyp
          opt,  # main中opt参数
          device,  # 当前设备
          callbacks  # 用于存储Loggers日志记录器中的函数,方便在每个训练阶段控制日志的记录情况
          ):
    # 从opt获取参数。日志保存路径,轮次、批次、权重、进程序号(主要用于分布式训练)等
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \\
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \\
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze


 

这段代码是接收传来的参数

  • hyp:  超参数,不使用超参数进化的前提下也可以从opt中获取
  • opt:  全部的命令行参数
  • device:  指的是装载程序的设备
  • callbacks:  指的是训练过程中产生的一些参数

4.1.2 创建训练权重目录和保存路径

    '''
    1.2 创建训练权重目录,设置模型、txt等保存的路径
    '''
    # Directories 获取记录训练日志的保存路径
    # 设置保存权重路径 如runs/train/exp1/weights
    w = save_dir / 'weights'  # weights dir
    # 新建文件夹 weights train evolve
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    # 保存训练结果的目录,如last.pt和best.pt
    last, best = w / 'last.pt', w / 'best.pt'

这段代码主要是创建权重文件保存路径,权重名字和训练日志txt文件

每次训练结束后,系统会产生两个模型,一个是last.pt,一个是best.pt。顾名思义,last.pt即为训练最后一轮产生的模型,而best.pt是训练过程中,效果最好的模型。

然后创建文件夹,保存训练结果的模型文件路径 以及验证集输出结果的txt文件路径,包含迭代的次数,占用显存大小,图片尺寸,精确率,召回率,位置损失,类别损失,置信度损失和map等。


4.1.3 读取超参数配置文件

    '''
    1.3 读取hyp(超参数)配置文件
    '''
    # Hyperparameters 加载超参数
    if isinstance(hyp, str): # isinstance()是否是已知类型。 判断hyp是字典还是字符串
        # 若hyp是字符串,即认定为路径,则加载超参数为字典
        with open(hyp, errors='ignore') as f:
            # 加载yaml文件
            hyp = yaml.safe_load(f)  # load hyps dict 加载超参信息
    # 打印超参数 彩色字体
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'k=v' for k, v in hyp.items()))

   

这段代码主要是加载一些训练过程中需要使用的超参数,并打印出来

首先,检查超参数是字典还是字符串,若为字符串,则认定为.yaml文件路径,再将yaml文件加载为字典。这里导致超参数的数据类型不同的原因是,超参数进化时,传入train()函数的超参数即为字典。而从命令行参数中读取的则为文件路径。

然后将打印这些超参数。


4.1.4 设置参数的保存路径

     '''
    1.4 设置参数的保存路径
    '''
    # Save run settings 保存训练中的参数hyp和opt
    with open(save_dir / 'hyp.yaml', 'w') as f:
        # 保存超参数为yaml配置文件
        yaml.safe_dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        # 保存命令行参数为yaml配置文件
        yaml.safe_dump(vars(opt), f, sort_keys=False)
        # 定义数据集字典
    data_dict = None

这段代码主要是将训练的相关参数全部写入 

将本次运行的超参数(hyp)选项操作(opt)给保存成yaml格式,保存在了每次训练得到的exp文件中,这两个yaml显示了我们本次训练所选择的hyp超参数和opt参数。

还有一点,yaml.safe_load(f)加载yaml的标准函数接口,保存超参数为yaml配置文件。 yaml.safe_dump()yaml文件序列化,保存命令行参数为yaml配置文件。
vars(opt) 的作用是把数据类型是Namespace的数据转换为字典的形式。


4.1.5 加载日志信息

    '''
    1.5 加载相关日志功能:如tensorboard,logger,wandb
    '''
    # Loggers 设置wandb和tb两种日志, wandb和tensorboard

以上是关于YOLOv5源码逐行超详细注释与解读——训练部分train.py的主要内容,如果未能解决你的问题,请参考以下文章

YOLOv5源码逐行超详细注释与解读——项目目录结构解析

YOLOv5源码逐行超详细注释与解读——网络结构yolo.py

YOLOv5 最详细的源码逐行解读

YOLOV5详细解读

YOLOv5全面解析教程①:网络结构逐行代码解读

Yolov5部署训练及代码解读