YOLOv5源码逐行超详细注释与解读——训练部分train.py
Posted 路人贾'ω'
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了YOLOv5源码逐行超详细注释与解读——训练部分train.py相关的知识,希望对你有一定的参考价值。
前言
本篇文章主要是对YOLOv5项目的训练部分train.py。通常这个文件主要是用来读取用户自己的数据集,加载模型并训练。
文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑!
友情提示:全文近5万字,可以先点☆收藏☆再慢慢看哦~
源码下载地址:mirrors / ultralytics / yolov5 · GitCode
YOLOv5源码逐行讲解系列前期回顾:
YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析
YOLOv5源码逐行超详细注释与解读(2)——推理部分detect.py
目录
🚀一、导包和基本配置
1.1 Usage
"""
Train a YOLOv5 model on a custom dataset
在数据集上训练 yolo v5 模型
Usage:
$ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
训练数据为coco128 coco128数据集中有128张图片 80个类别,是规模较小的数据集
"""
这里是开头作者注释的一个部分,意在说明一些项目基本情况。
第一行表示我们用的模型是YOLOv5;
第二行表示我们传入的data数据集是coco128数据集,有128张图片,80个类别,使用的权重模型是yolov5s模型,–img表示图片大小640。
1.2 导入安装好的python库
'''======================1.导入安装好的python库====================='''
import argparse # 解析命令行参数模块
import math # 数学公式模块
import os # 与操作系统进行交互的模块 包含文件路径操作和解析
import random # 生成随机数模块
import sys # sys系统模块 包含了与Python解释器和它的环境有关的函数
import time # 时间模块 更底层
from copy import deepcopy # 深度拷贝模块
from datetime import datetime # datetime模块能以更方便的格式显示日期或对日期进行运算。
from pathlib import Path # Path将str转换为Path对象 使字符串路径易于操作的模块
import numpy as np # numpy数组操作模块
import torch # 引入torch
import torch.distributed as dist # 分布式训练模块
import torch.nn as nn # 对torch.nn.functional的类的封装 有很多和torch.nn.functional相同的函数
import yaml # yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件。
from torch.cuda import amp # PyTorch amp自动混合精度训练模块
from torch.nn.parallel import DistributedDataParallel as DDP # 多卡训练模块
from torch.optim import SGD, Adam, lr_scheduler # tensorboard模块
from tqdm import tqdm # 进度条模块
首先,导入一下常用的python库:
- argparse: 它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息
- math: 调用这个库进行数学运算
- os: 它提供了多种操作系统的接口。通过os模块提供的操作系统接口,我们可以对操作系统里文件、终端、进程等进行操作
- random: 是使用随机数的Python标准库。random库主要用于生成随机数
- sys: 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 Python 运行时环境的不同部分
- time: Python中处理时间的标准库,是最基础的时间处理库
- copy: Python 中赋值语句不复制对象,而是在目标和对象之间创建绑定 (bindings) 关系。copy模块提供了通用的浅层复制和深层复制操作
- datetime: 是Python常用的一个库,主要用于时间解析和计算
- pathlib: 这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读
然后再导入一些 pytorch库:
- numpy: 科学计算库,提供了矩阵,线性代数,傅立叶变换等等的解决方案, 最常用的是它的N维数组对象
- torch: 这是主要的Pytorch库。它提供了构建、训练和评估神经网络的工具
- torch.distributed: torch.distributed包提供Pytorch支持和通信基元,对多进程并行,在一个或多个机器上运行的若干个计算阶段
- torch.nn: torch下包含用于搭建神经网络的modules和可用于继承的类的一个子包
- yaml: yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件
- torch.cuda.amp: 自动混合精度训练 —— 节省显存并加快推理速度
- torch.nn.parallel: 构建分布式模型,并行加速程度更高,且支持多节点多gpu的硬件拓扑结构
- torch.optim: 优化器 Optimizer。主要是在模型训练阶段对模型可学习参数进行更新,常用优化器有 SGD,RMSprop,Adam等
- tqdm: 就是我们看到的训练时进度条显示
1.3 获取当前文件的绝对路径
'''===================2.获取当前文件的绝对路径========================'''
FILE = Path(__file__).resolve() # __file__指的是当前文件(即train.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/train.py
ROOT = FILE.parents[0] # YOLOv5 root directory ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path: # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径
sys.path.append(str(ROOT)) # add ROOT to PATH 把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ROOT设置为相对路径
这段代码会获取当前文件的绝对路径,并使用Path库将其转换为Path对象。
这一部分的主要作用有两个:
- 将当前项目添加到系统路径上,以使得项目中的模块可以调用。
- 将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件。
1.4 加载自定义模块
'''===================3..加载自定义模块============================'''
import val # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, NCOLS, check_dataset, check_file, check_git_status, check_img_size,
check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:
- val: 这个是测试集,我们下一篇再具体讲
- models.experimental: 实验性质的代码,包括MixConv2d、跨层权重Sum等
- models.yolo: yolo的特定模块,包括BaseModel,DetectionModel,ClassificationModel,parse_model等
- utils.autoanchor: 定义了自动生成锚框的方法
- utils.autobatch: 定义了自动生成批量大小的方法
- utils.callbacks: 定义了回调函数,主要为logger服务
- utils.datasets: dateset和dateloader定义代码
- utils.downloads: 谷歌云盘内容下载
- utils.general.py: 定义了一些常用的工具函数,比如检查文件是否存在、检查图像大小是否符合要求、打印命令行参数等等
- utils.loggers : 日志打印
- utils.loss: 存放各种损失函数
- utils.metrics: 模型验证指标,包括ap,混淆矩阵等
- utils.plots.py: 定义了Annotator类,可以在图像上绘制矩形框和标注信息
- utils.torch_utils.py: 定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等
通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。
1.5 分布式训练初始化
'''================4.分布式训练初始化==========================='''
# https://pytorch.org/docs/stable/elastic/run.html该网址有详细介绍
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # -本地序号。这个 Worker 是这台机器上的第几个 Worker
RANK = int(os.getenv('RANK', -1)) # -进程序号。这个 Worker 是全局第几个 Worker
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # 总共有几个 Worker
'''
查找名为LOCAL_RANK,RANK,WORLD_SIZE的环境变量,
若存在则返回环境变量的值,若不存在则返回第二个参数(-1,默认None)
rank和local_rank的区别: 两者的区别在于前者用于进程间通讯,后者用于本地设备分配。
'''
接下来是设置分布式训练时所需的环境变量。分布式训练指的是多GPU训练,将训练参数分布在多个GPU上进行训练,有利于提升训练效率。
🚀二、执行main()函数
2.1 检查分布式训练环境
def main(opt, callbacks=Callbacks()):
'''
2.1 检查分布式训练环境
'''
# Checks
if RANK in [-1, 0]: # 若进程编号为-1或0
# 输出所有训练参数 / 参数以彩色的方式表现
print_args(FILE.stem, opt)
# 检测YOLO v5的github仓库是否更新,若已更新,给出提示
check_git_status()
# 检查requirements.txt所需包是否都满足
check_requirements(exclude=['thop'])
这段代码主要是检查分布式训练的环境。
若RANK为-1或0,会执行下面三行代码,打印参数并检查github仓库和依赖库。
- 第一行代码,负责打印文件所用到的参数信息,这个参数包括命令行传入进去的参数以及默认参数
- 第二行代码,检查yolov5的github仓库是否更新,如果更新的话,会有一个提示
- 第三行代码,检查requirements中要求的安装包有没有正确安装成功,没有成功的话会给予一定的提示
2.2 判断是否断点续训
'''
2.2 判断是否断点续训
'''
# Resume
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
# isinstance()是否是已经知道的类型
# 如果resume是True,则通过get_lastest_run()函数找到runs为文件夹中最近的权重文件last.pt
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
# 判断是否为文件,若不是文件抛出异常
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
# opt.yaml是训练时的命令行参数文件
with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
# 超参数替换,将训练时的命令行参数加载进opt参数对象中
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
# opt.cfg设置为'' 对应着train函数里面的操作(加载权重时是否加载权重里的anchor)
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
# 打印从ckpt恢复断点训练信息
LOGGER.info(f'Resuming training from ckpt')
else:
# 不使用断点续训,就从文件中读取相关参数
# check_file (utils/general.py)的作用为查找/下载文件 并返回该文件的路径。
opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \\
check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
# 如果模型文件和权重文件为空,弹出警告
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
# 如果要进行超参数进化,重建保存路径
if opt.evolve:
# 设置新的项目输出目录
opt.project = str(ROOT / 'runs/evolve')
# 将resume传递给exist_ok
opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
# 根据opt.project生成目录,并赋值给opt.save_dir 如: runs/train/exp1
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
这段代码主要是关于断点训练的判断和准备。
断点训练是当训练异常终止或想调节超参数时,系统会保留训练异常终止前的超参数与训练参数,当下次训练开始时,并不会从头开始,而是从上次中断的地方继续训练。
- 使用断点续训,就从last.pt中读取相关参数
- 不使用断点续训,就从文件中读取相关参数
2.3 判断是否分布式训练
'''
2.3 判断是否分布式训练
'''
# DDP mode --> 支持多机多卡、分布式训练
# 选择程序装载的位置
device = select_device(opt.device, batch_size=opt.batch_size)
# 当进程内的GPU编号不为-1时,才会进入DDP
if LOCAL_RANK != -1:
# 用于DDP训练的GPU数量不足
assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
# WORLD_SIZE表示全局的进程数
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
# 不能使用图片采样策略
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
# 不能使用超参数进化
assert not opt.evolve, '--evolve argument is not compatible with DDP training'
# 设置装载程序设备
torch.cuda.set_device(LOCAL_RANK)
# 保存装载程序的设备
device = torch.device('cuda', LOCAL_RANK)
# torch.distributed是用于多GPU训练的模块
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
这段代码主要是检查DDP训练的配置,并设置GPU。
DDP(Distributed Data Parallel)用于单机或多机的多GPU分布式训练,但目前DDP只能在Linux下使用。这部分它会选择你是使用cpu还是gpu,假如你采用的是分布式训练的话,它就会额外执行下面的一些操作,我们这里一般不会用到分布式,所以也就没有执行什么东西。
2.4 判断是否进化训练
'''
2.4 判断是否进化训练
'''
# Train 训练模式: 如果不进行超参数进化,则直接调用train()函数,开始训练
if not opt.evolve:# 如果不使用超参数进化
# 开始训练
train(opt.hyp, opt, device, callbacks)
if WORLD_SIZE > 1 and RANK == 0:
# 如果全局进程数大于1并且RANK等于0
# 日志输出 销毁进程组
LOGGER.info('Destroying process group... ')
# 训练完毕,销毁所有进程
dist.destroy_process_group()
这段代码是不进行进化训练的情况,此时正常训练。
如果输入evolve会执行else下面这些代码,因为我们没有输入evolve并且不是分布式训练,因此会执行train函数。也就是说,当不使用超参数进化训练时,直接把命令行参数传入train函数,训练完成后销毁所有进程。
接下来我们再看看使用超参数进化训练的情况:
# Evolve hyperparameters (optional) 遗传进化算法,边进化边训练
else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
# 超参数列表(突变范围 - 最小值 - 最大值)
meta = 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'box': (1, 0.02, 0.2), # box loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0) # segment copy-paste (probability)
# 加载默认超参数
with open(opt.hyp, errors='ignore') as f:
hyp = yaml.safe_load(f) # load hyps dict
# 如果超参数文件中没有'anchors',则设为3
if 'anchors' not in hyp: # anchors commented in hyp.yaml
hyp['anchors'] = 3
# 使用进化算法时,仅在最后的epoch测试和保存
opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
if opt.bucket:
os.system(f'gsutil cp gs://opt.bucket/evolve.csv save_dir') # download evolve.csv if exists
"""
遗传算法调参:遵循适者生存、优胜劣汰的法则,即寻优过程中保留有用的,去除无用的。
遗传算法需要提前设置4个参数: 群体大小/进化代数/交叉概率/变异概率
"""
这段代码是使用超参数进化训练的前期准备
首先指定每个超参数的突变范围、最大值、最小值,再为超参数的结果保存做好准备。
# 选择超参数的遗传迭代次数 默认为迭代300次
for _ in range(opt.evolve): # generations to evolve
# 如果evolve.csv文件存在
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
# Select parent(s)
# 选择超参进化方式,只用single和weighted两种
parent = 'single' # parent selection method: 'single' or 'weighted'
# 加载evolve.txt
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
# 选取至多前五次进化的结果
n = min(5, len(x)) # number of previous results to consider
# fitness()为x前四项加权 [P, R, mAP@0.5, mAP@0.5:0.95]
# np.argsort只能从小到大排序, 添加负号实现从大到小排序, 算是排序的一个代码技巧
x = x[np.argsort(-fitness(x))][:n] # top n mutations
# 根据(mp, mr, map50, map)的加权和来作为权重计算hyp权重
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
# 根据不同进化方式获得base hyp
if parent == 'single' or len(x) == 1:
# 根据权重的几率随机挑选适应度历史前5的其中一个
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == 'weighted':
# 对hyp乘上对应的权重融合层一个hpy, 再取平均(除以权重和)
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate 突变(超参数进化)
mp, s = 0.8, 0.2 # mutation probability, sigma:突变概率
npr = np.random
# 根据时间设置随机数种子
npr.seed(int(time.time()))
# 获取突变初始值, 也就是meta三个值的第一个数据
# 三个数值分别对应着: 变异初始概率, 最低限值, 最大限值(mutation scale 0-1, lower_limit, upper_limit)
g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
ng = len(meta)
# 确保至少其中有一个超参变异了
v = np.ones(ng)
# 设置突变
while all(v == 1): # mutate until a change occurs (prevent duplicates)
v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
# 将突变添加到base hyp上
for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
hyp[k] = float(x[i + 7] * v[i]) # mutate
# Constrain to limits 限制hyp在规定范围内
for k, v in meta.items():
# 这里的hyp是超参数配置文件对象
# 而这里的k和v是在元超参数中遍历出来的
# hyp的v是一个数,而元超参数的v是一个元组
hyp[k] = max(hyp[k], v[1]) # 先限定最小值,选择二者之间的大值 ,这一步是为了防止hyp中的值过小
hyp[k] = min(hyp[k], v[2]) # 再限定最大值,选择二者之间的小值
hyp[k] = round(hyp[k], 5) # 四舍五入到小数点后五位
# 最后的值应该是 hyp中的值与 meta的最大值之间的较小者
# Train mutation 使用突变后的参超,测试其效果
results = train(hyp.copy(), opt, device, callbacks)
# Write mutation results
# 将结果写入results,并将对应的hyp写到evolve.txt,evolve.txt中每一行为一次进化的结果
# 每行前七个数字 (P, R, mAP, F1, test_losses(GIOU, obj, cls)) 之后为hyp
# 保存hyp到yaml文件
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
# Plot results 将结果可视化 / 输出保存信息
plot_evolve(evolve_csv)
LOGGER.info(f'Hyperparameter evolution finished\\n'
f"Results saved to colorstr('bold', save_dir)\\n"
f'Use best hyperparameters example: $ python train.py --hyp evolve_yaml')
这段代码是开始超参数进化训练。
超参数进化的步骤如下:
- 1.若存在evolve.csv文件,读取文件中的训练数据,选择超参进化方式,结果最优的训练数据突变超参数
- 2.限制超参进化参数hyp在规定范围内
- 3.使用突变后的超参数进行训练,测试其效果
- 4.训练结束后,将训练结果可视化,输出保存信息保存至evolution.csv,用于下一次的超参数突变。
原理:根据生物进化,优胜劣汰,适者生存的原则,每次迭代都会保存更优秀的结果,直至迭代结束。最后的结果即为最优的超参数
注意:使用超参数进化时要经过至少300次迭代,每次迭代都会经过一次完整的训练。因此超参数进化及其耗时,大家需要根据自己需求慎用。
🚀三、设置opt参数
=============================================三、设置opt参数==================================================='''
def parse_opt(known=False):
parser = argparse.ArgumentParser()
# 预训练权重文件
parser.add_argument('--weights', type=str, default=ROOT / 'pretrained/yolov5s.pt', help='initial weights path')
# 训练模型
parser.add_argument('--cfg', type=str, default=ROOT / 'models/yolov5s.yaml', help='model.yaml path')
# 训练路径,包括训练集,验证集,测试集的路径,类别总数等
parser.add_argument('--data', type=str, default=ROOT / 'data/fire_data.yaml', help='dataset.yaml path')
# hpy超参数设置文件(lr/sgd/mixup)./data/hyps/下面有5个超参数设置文件,每个文件的超参数初始值有细微区别,用户可以根据自己的需求选择其中一个
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
# epochs: 训练轮次, 默认轮次为300次
parser.add_argument('--epochs', type=int, default=300)
# batchsize: 训练批次, 默认bs=16
parser.add_argument('--batch-size', type=int, default=4, help='total batch size for all GPUs, -1 for autobatch')
# imagesize: 设置图片大小, 默认640*640
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
# rect: 是否采用矩形训练,默认为False
parser.add_argument('--rect', action='store_true', help='rectangular training')
# resume: 是否接着上次的训练结果,继续训练
# 矩形训练:将比例相近的图片放在一个batch(由于batch里面的图片shape是一样的)
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
# nosave: 不保存模型 默认False(保存) 在./runs/exp*/train/weights/保存两个模型 一个是最后一次的模型 一个是最好的模型
# best.pt/ last.pt 不建议运行代码添加 --nosave
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
# noval: 最后进行测试, 设置了之后就是训练结束都测试一下, 不设置每轮都计算mAP, 建议不设置
parser.add_argument('--noval', action='store_true', help='only validate final epoch')
# noautoanchor: 不自动调整anchor, 默认False, 自动调整anchor
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
# evolve: 参数进化, 遗传算法调参
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
# bucket: 谷歌优盘 / 一般用不到
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
# cache: 是否提前缓存图片到内存,以加快训练速度,默认False
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
# mage-weights: 使用图片采样策略,默认不使用
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
# device: 设备选择
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
# parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
# multi-scale 是否进行多尺度训练
parser.add_argument('--multi-scale', default=True, help='vary img-size +/- 50%%')
# single-cls: 数据集是否多类/默认True
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
# optimizer: 优化器选择 / 提供了三种优化器
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
# sync-bn: 是否使用跨卡同步BN,在DDP模式使用
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
# dataloader的最大worker数量 (使用多线程加载图片)
parser.add_argument('--workers', type=int, default=0, help='max dataloader workers (per RANK in DDP mode)')
# 训练结果的保存路径
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
# 训练结果的文件名称
parser.add_argument('--name', default='exp', help='save to project/name')
# 项目位置是否存在 / 默认是都不存在
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
# 四元数据加载器: 允许在较低 --img 尺寸下进行更高 --img 尺寸训练的一些好处。
parser.add_argument('--quad', action='store_true', help='quad dataloader')
# cos-lr: 余弦学习率
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
# 标签平滑 / 默认不增强, 用户可以根据自己标签的实际情况设置这个参数,建议设置小一点 0.1 / 0.05
parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
# 早停止耐心次数 / 100次不更新就停止训练
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
# --freeze冻结训练 可以设置 default = [0] 数据量大的情况下,建议不设置这个参数
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
# --save-period 多少个epoch保存一下checkpoint
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
# --local_rank 进程编号 / 多卡使用
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
# Weights & Biases arguments
# 在线可视化工具,类似于tensorboard工具
parser.add_argument('--entity', default=None, help='W&B: Entity')
# upload_dataset: 是否上传dataset到wandb tabel(将数据集作为交互式 dsviz表 在浏览器中查看、查询、筛选和分析数据集) 默认False
parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
# bbox_interval: 设置界框图像记录间隔 Set bounding-box image logging interval for W&B 默认-1 opt.epochs // 10
parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
# 使用数据的版本
parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
# 作用就是当仅获取到基本设置时,如果运行命令中传入了之后才会获取到的其他配置,不会报错;而是将多出来的部分保存起来,留到后面使用
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt
opt参数解析:
- cfg: 模型配置文件,网络结构
- data: 数据集配置文件,数据集路径,类名等
- hyp: 超参数文件
- epochs: 训练总轮次
- batch-size: 批次大小
- img-size: 输入图片分辨率大小
- rect: 是否采用矩形训练,默认False
- resume: 接着打断训练上次的结果接着训练
- nosave: 不保存模型,默认False
- notest: 不进行test,默认False
- noautoanchor: 不自动调整anchor,默认False
- evolve: 是否进行超参数进化,默认False
- bucket: 谷歌云盘bucket,一般不会用到
- cache-images: 是否提前缓存图片到内存,以加快训练速度,默认False
- weights: 加载的权重文件
- name: 数据集名字,如果设置:results.txt to results_name.txt,默认无
- device: 训练的设备,cpu;0(表示一个gpu设备cuda:0);0,1,2,3(多个gpu设备)
- multi-scale: 是否进行多尺度训练,默认False
- single-cls: 数据集是否只有一个类别,默认False
- adam: 是否使用adam优化器
- sync-bn: 是否使用跨卡同步BN,在DDP模式使用
- local_rank: gpu编号
- logdir: 存放日志的目录 workers: dataloader的最大worker数量
(关于调参,推荐大家看@迪菲赫尔曼大佬的这篇文章:手把手带你调参YOLOv5 (v5.0-v7.0)(验证))
🚀四、执行train()函数
4.1 加载参数和初始化配置信息
4.1.1 载入参数
''' =====================1.载入参数和初始化配置信息========================== '''
'''
1.1 载入参数
'''
def train(hyp, # 超参数 可以是超参数配置文件的路径或超参数字典 path/to/hyp.yaml or hyp
opt, # main中opt参数
device, # 当前设备
callbacks # 用于存储Loggers日志记录器中的函数,方便在每个训练阶段控制日志的记录情况
):
# 从opt获取参数。日志保存路径,轮次、批次、权重、进程序号(主要用于分布式训练)等
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \\
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \\
opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
这段代码是接收传来的参数。
- hyp: 超参数,不使用超参数进化的前提下也可以从opt中获取
- opt: 全部的命令行参数
- device: 指的是装载程序的设备
- callbacks: 指的是训练过程中产生的一些参数
4.1.2 创建训练权重目录和保存路径
'''
1.2 创建训练权重目录,设置模型、txt等保存的路径
'''
# Directories 获取记录训练日志的保存路径
# 设置保存权重路径 如runs/train/exp1/weights
w = save_dir / 'weights' # weights dir
# 新建文件夹 weights train evolve
(w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
# 保存训练结果的目录,如last.pt和best.pt
last, best = w / 'last.pt', w / 'best.pt'
这段代码主要是创建权重文件保存路径,权重名字和训练日志txt文件
每次训练结束后,系统会产生两个模型,一个是last.pt,一个是best.pt。顾名思义,last.pt即为训练最后一轮产生的模型,而best.pt是训练过程中,效果最好的模型。
然后创建文件夹,保存训练结果的模型文件路径 以及验证集输出结果的txt文件路径,包含迭代的次数,占用显存大小,图片尺寸,精确率,召回率,位置损失,类别损失,置信度损失和map等。
4.1.3 读取超参数配置文件
'''
1.3 读取hyp(超参数)配置文件
'''
# Hyperparameters 加载超参数
if isinstance(hyp, str): # isinstance()是否是已知类型。 判断hyp是字典还是字符串
# 若hyp是字符串,即认定为路径,则加载超参数为字典
with open(hyp, errors='ignore') as f:
# 加载yaml文件
hyp = yaml.safe_load(f) # load hyps dict 加载超参信息
# 打印超参数 彩色字体
LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'k=v' for k, v in hyp.items()))
这段代码主要是加载一些训练过程中需要使用的超参数,并打印出来
首先,检查超参数是字典还是字符串,若为字符串,则认定为.yaml文件路径,再将yaml文件加载为字典。这里导致超参数的数据类型不同的原因是,超参数进化时,传入train()函数的超参数即为字典。而从命令行参数中读取的则为文件路径。
然后将打印这些超参数。
4.1.4 设置参数的保存路径
'''
1.4 设置参数的保存路径
'''
# Save run settings 保存训练中的参数hyp和opt
with open(save_dir / 'hyp.yaml', 'w') as f:
# 保存超参数为yaml配置文件
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
# 保存命令行参数为yaml配置文件
yaml.safe_dump(vars(opt), f, sort_keys=False)
# 定义数据集字典
data_dict = None
这段代码主要是将训练的相关参数全部写入
将本次运行的超参数(hyp)和选项操作(opt)给保存成yaml格式,保存在了每次训练得到的exp文件中,这两个yaml显示了我们本次训练所选择的hyp超参数和opt参数。
还有一点,yaml.safe_load(f)是加载yaml的标准函数接口,保存超参数为yaml配置文件。 yaml.safe_dump()是将yaml文件序列化,保存命令行参数为yaml配置文件。vars(opt)
的作用是把数据类型是Namespace的数据转换为字典的形式。
4.1.5 加载日志信息
'''
1.5 加载相关日志功能:如tensorboard,logger,wandb
'''
# Loggers 设置wandb和tb两种日志, wandb和tensorboard以上是关于YOLOv5源码逐行超详细注释与解读——训练部分train.py的主要内容,如果未能解决你的问题,请参考以下文章