YOLOv5 最详细的源码逐行解读
Posted supermax2020
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了YOLOv5 最详细的源码逐行解读相关的知识,希望对你有一定的参考价值。
所用版本: v6.1
本文解读detect.py
源代码地址: YOLO v5
1. 加载系统库 27~33行
import argparse
import os
import sys
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
首先加载必要的外部库, 在使用时我们再介绍他们的用法
2. 设置系统环境 34~40行
FILE = Path(__file__).resolve() # __file__指的是当前文件(即detect.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/detect.py
ROOT = FILE.parents[0] # ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path: # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径
sys.path.append(str(ROOT)) # 把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # ROOT设置为相对路径
这一部分的主要作用有两个:
- 将当前项目添加到系统路径上,以使得项目中的模块可以调用.
- 将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件.
3. 加载自定义模块 41~47行
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。
用的时候再解释这些库/方法的作用
4. run方法 48~213行
4.1 参数列表 48~77行
@torch.no_grad()# 该标注使得方法中所有计算得出的tensor的requires_grad都自动设置为False,也就是说不会求梯度,可以加快预测效率,减小资源消耗
def run(
weights=ROOT / 'yolov5s.pt', # 事先训练完成的权重文件,比如yolov5s.pt,假如使用官方训练好的文件(比如yolov5s),则会自动下载
source=ROOT / 'data/images', # 预测时的输入数据,可以是文件/路径/URL/glob, 输入是0的话调用摄像头作为输入
data=ROOT / 'data/coco128.yaml', # 数据集文件
imgsz=(640, 640), # 预测时的放缩后图片大小(因为YOLO算法需要预先放缩图片), 两个值分别是height, width
conf_thres=0.25, # 置信度阈值, 高于此值的bounding_box才会被保留
iou_thres=0.45, # IOU阈值,高于此值的bounding_box才会被保留
max_det=1000, # 一张图片上检测的最大目标数量
device='', # 所使用的GPU编号,如果使用CPU就写cpu
view_img=False, # 是否在推理时预览图片
save_txt=False, # save results to *.txt 是否将结果保存在txt文件中
save_conf=False, # save confidences in --save-txt labels 是否将结果中的置信度保存在txt文件中
save_crop=False, # save cropped prediction boxes 是否保存裁剪后的预测框
nosave=False, # do not save images/videos 是否保存预测后的图片/视频
classes=None, # 过滤指定类的预测结果
agnostic_nms=False, # 如为True,则为class-agnostic. 否则为class-specific
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project=ROOT / 'runs/detect', # 推理结果保存的路径
name='exp', # 结果保存文件夹的命名前缀
exist_ok=False, # True: 推理结果覆盖之前的结果 False: 推理结果新建文件夹保存,文件夹名递增
line_thickness=3, # 绘制Bounding_box的线宽度
hide_labels=False, # True: 隐藏标签
hide_conf=False, # True: 隐藏置信度
half=False, # use FP16 half-precision inference 是否使用半精度推理(节约显存)
dnn=False, # use OpenCV DNN for ONNX inference
):
这里仅把一些必要的参数注释了一下,其他参数在实际使用中可以使用默认值,如果识别效果不好可以考虑修改参数(但更可能是训练的问题而不是这些参数的问题)
4.2 初始化环境,加载模型 78~105行
source = str(source)
save_img = not nosave and not source.endswith('.txt') # 是否需要保存图片,如果nosave(传入的参数)为false且source的结尾不是txt则保存图片
# 后面这个source.endswith('.txt')也就是source以.txt结尾,不过我不清楚这是什么用法
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
# 判断source是不是视频/图像文件路径
# 假如source是"D://YOLOv5/data/1.jpg",则Path(source).suffix是".jpg",Path(source).suffix[1:]是"jpg"
# 而IMG_FORMATS 和 VID_FORMATS两个变量保存的是所有的视频和图片的格式后缀。
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))# 判断source是否是链接
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)# 判断是source是否是摄像头
if is_url and is_file:
source = check_file(source) # 如果source是一个指向图片/视频的链接,则下载输入数据
# Directories
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # save_dir是保存运行结果的文件夹名,是通过递增的方式来命名的。第一次运行时路径是“runs\\detect\\exp”,第二次运行时路径是“runs\\detect\\exp1”
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # 根据前面生成的路径创建文件夹
# 加载模型
device = select_device(device)# select_device方法定义在utils.torch_utils模块中,返回值是torch.device对象,也就是推理时所使用的硬件资源。输入值如果是数字,表示GPU序号。也可是输入‘cpu’,表示使用CPU训练,默认是cpu
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)# DetectMultiBackend定义在models.common模块中,是我们要加载的网络,其中weights参数就是输入时指定的权重文件(比如yolov5s.pt)
stride, names, pt = model.stride, model.names, model.pt
# stride:推理时所用到的步长,默认为32, 大步长适合于大目标,小步长适合于小目标
# names:保存推理结果名的列表,比如默认模型的值是['person', 'bicycle', 'car', ...]
# pt: 加载的是否是pytorch模型(也就是pt格式的文件),
imgsz = check_img_size(imgsz, s=stride)
# 将图片大小调整为步长的整数倍
# 比如假如步长是10,imagesz是[100,101],则返回值是[100,100]
# Dataloader
if webcam:# 使用摄像头作为输入
view_img = check_imshow()# 检测cv2.imshow()方法是否可以执行,不能执行则抛出异常
cudnn.benchmark = True # 该设置可以加速预测
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)# 加载输入数据流
# source:输入数据源 image_size 图片识别前被放缩的大小, stride:识别时的步长,
# auto的作用可以看utils.augmentations.letterbox方法,它决定了是否需要将图片填充为正方形,如果auto=True则不需要
bs = len(dataset) # batch_size 批大小
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs# 用于保存视频,前者是视频路径,后者是一个cv2.VideoWriter对象
4.3 开始预测 106~203行
# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # 使用空白图片(零矩阵)预先用GPU跑一遍预测流程,可以加速预测
seen, windows, dt = 0, [], [0.0, 0.0, 0.0]
# seen: 已经处理完了多少帧图片
# windows: 如果需要预览图片,windows列表会给每个输入文件存储一个路径.
# dt: 存储每一步骤的耗时
for path, im, im0s, vid_cap, s in dataset:
# 在dataset中,每次迭代的返回值是self.sources, img, img0, None, ''
#path:文件路径(即source)
#im: 处理后的输入图片列表(经过了放缩操作)
#im0s: 源输入图片列表
#vid_cap
# s: 图片的基本信息,比如路径,大小
t1 = time_sync()# 获取当前时间
im = torch.from_numpy(im).to(device)#将图片放到指定设备(如GPU)上识别
im = im.half() if model.fp16 else im.float() # 把输入从整型转化为半精度/全精度浮点数。
im /= 255 # 0 - 255 to 0.0 - 1.0 #将图片归一化处理(这是图像表示方法的的规范,使用浮点数就要归一化)
if len(im.shape) == 3:
im = im[None] # 添加一个第0维。在pytorch的nn.Module的输入中,第0维是batch的大小,这里添加一个1。
t2 = time_sync() # 获取当前时间
dt[0] += t2 - t1 # 记录该阶段耗时
# Inference
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
# 如果为True则保留推理过程中的特征图,保存在runs文件夹中
pred = model(im, augment=augment, visualize=visualize)
# 推理结果,pred保存的是所有的bound_box的信息,
t3 = time_sync()
dt[1] += t3 - t2# 记录该阶段耗时
# NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
# 执行非极大值抑制,返回值为过滤后的预测框
# conf_thres: 置信度阈值
# iou_thres: iou阈值
# classes: 需要过滤的类(数字列表)
# agnostic_nms: 标记class-agnostic或者使用class-specific方式。默认为class-agnostic
# max_det: 检测框结果的最大数量
dt[2] += time_sync() - t3
# Second-stage classifier (optional)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
# Process predictions
for i, det in enumerate(pred): # 每次迭代处理一张图片,
seen += 1
if webcam: # batch_size >= 1
p, im0, frame = path[i], im0s[i].copy(), dataset.count
#frame:此次取的是第几张图片
s += f'i: '# s后面拼接一个字符串i
else:
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
p = Path(p) # to Path
save_path = str(save_dir / p.name) # 推理结果图片保存的路径
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_frame') # 推理结果文本保存的路径
s += '%gx%g ' % im.shape[2:] # 显示推理前裁剪后的图像尺寸
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
#得到原图的宽和高
imc = im0.copy() if save_crop else im0 # for save_crop
#如果save_crop的值为true, 则将检测到的bounding_box单独保存成一张图片。
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
# 得到一个绘图的类,类中预先存储了原图、线条宽度、类名
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
# 将标注的bounding_box大小调整为和原图一致(因为训练时原图经过了放缩)
# Print results
for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class
s += f"n names[int(c)]'s' * (n > 1), " # add to string
# 打印出所有的预测结果 比如1 person(检测出一个人)
# Write results
for *xyxy, conf, cls in reversed(det):
if save_txt: # 保存txt文件
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
# 将坐标转变成x y w h 的形式,并归一化
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
# line的形式是: ”类别 x y w h“,假如save_conf为true,则line的形式是:”类别 x y w h 置信度“
with open(f'txt_path.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\\n')
# 写入对应的文件夹里,路径默认为“runs\\detect\\exp*\\labels”
if save_img or save_crop or view_img: # 给图片添加推理后的bounding_box边框
c = int(cls) # 类别标号
label = None if hide_labels else (names[c] if hide_conf else f'names[c] conf:.2f')# 类别名
annotator.box_label(xyxy, label, color=colors(c, True))
#绘制边框
if save_crop:# 将预测框内的图片单独保存
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'p.stem.jpg', BGR=True)
# Stream results
im0 = annotator.result()
#im0是绘制好的图片
if view_img:# 如果view_img为true,则显示该图片
if p not in windows: # 如果当前图片/视频的路径不在windows列表里,则说明需要重新为该图片/视频创建一个预览窗口
windows.append(p)# 标记当前图片/视频已经创建好预览窗口了
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0) # 预览图片
cv2.waitKey(1) # 暂停 1 millisecond
# Save results (image with detections)
if save_img:# 如果save_img为true,则保存绘制完的图片
if dataset.mode == 'image':# 如果是图片,则保存
cv2.imwrite(save_path, im0)
else: # 如果是视频或者"流"
if vid_path[i] != save_path: # vid_path[i] != save_path,说明这张图片属于一段新的视频,需要重新创建视频文件
vid_path[i] = save_path
if isinstance(vid_writer[i], cv2.VideoWriter):
vid_writer[i].release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer[i].write(im0)
# 以上的部分是保存视频文件
# Print time (inference-only)
LOGGER.info(f'sDone. (t3 - t2:.3fs)')# 打印耗时
4.4 打印结果 204~212行
t = tuple(x / seen * 1E3 for x in dt) # 平均每张图片所耗费时间
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape (1, 3, *imgsz)' % t)
if save_txt or save_img:
s = f"\\nlen(list(save_dir.glob('labels/*.txt'))) labels saved to save_dir / 'labels'" if save_txt else ''# 标签保存的路径
LOGGER.info(f"Results saved to colorstr('bold', save_dir)s")
if update:
strip_optimizer(weights) # update model (to fix SourceChangeWarning)
5. 其余代码
parse_opt方法的解释和run的参数解释一致,不再重复解释
如果不明白如何修改参数使用请使用搜索引擎搜索ArgumentParser的用法
YOLOv5源码逐行超详细注释与解读——训练部分train.py
前言
本篇文章主要是对YOLOv5项目的训练部分train.py。通常这个文件主要是用来读取用户自己的数据集,加载模型并训练。
文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑!
友情提示:全文近5万字,可以先点☆收藏☆再慢慢看哦~
源码下载地址:mirrors / ultralytics / yolov5 · GitCode
YOLOv5源码逐行讲解系列前期回顾:
YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析
YOLOv5源码逐行超详细注释与解读(2)——推理部分detect.py
目录
🚀一、导包和基本配置
1.1 Usage
"""
Train a YOLOv5 model on a custom dataset
在数据集上训练 yolo v5 模型
Usage:
$ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
训练数据为coco128 coco128数据集中有128张图片 80个类别,是规模较小的数据集
"""
这里是开头作者注释的一个部分,意在说明一些项目基本情况。
第一行表示我们用的模型是YOLOv5;
第二行表示我们传入的data数据集是coco128数据集,有128张图片,80个类别,使用的权重模型是yolov5s模型,–img表示图片大小640。
1.2 导入安装好的python库
'''======================1.导入安装好的python库====================='''
import argparse # 解析命令行参数模块
import math # 数学公式模块
import os # 与操作系统进行交互的模块 包含文件路径操作和解析
import random # 生成随机数模块
import sys # sys系统模块 包含了与Python解释器和它的环境有关的函数
import time # 时间模块 更底层
from copy import deepcopy # 深度拷贝模块
from datetime import datetime # datetime模块能以更方便的格式显示日期或对日期进行运算。
from pathlib import Path # Path将str转换为Path对象 使字符串路径易于操作的模块
import numpy as np # numpy数组操作模块
import torch # 引入torch
import torch.distributed as dist # 分布式训练模块
import torch.nn as nn # 对torch.nn.functional的类的封装 有很多和torch.nn.functional相同的函数
import yaml # yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件。
from torch.cuda import amp # PyTorch amp自动混合精度训练模块
from torch.nn.parallel import DistributedDataParallel as DDP # 多卡训练模块
from torch.optim import SGD, Adam, lr_scheduler # tensorboard模块
from tqdm import tqdm # 进度条模块
首先,导入一下常用的python库:
- argparse: 它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息
- math: 调用这个库进行数学运算
- os: 它提供了多种操作系统的接口。通过os模块提供的操作系统接口,我们可以对操作系统里文件、终端、进程等进行操作
- random: 是使用随机数的Python标准库。random库主要用于生成随机数
- sys: 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 Python 运行时环境的不同部分
- time: Python中处理时间的标准库,是最基础的时间处理库
- copy: Python 中赋值语句不复制对象,而是在目标和对象之间创建绑定 (bindings) 关系。copy模块提供了通用的浅层复制和深层复制操作
- datetime: 是Python常用的一个库,主要用于时间解析和计算
- pathlib: 这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读
然后再导入一些 pytorch库:
- numpy: 科学计算库,提供了矩阵,线性代数,傅立叶变换等等的解决方案, 最常用的是它的N维数组对象
- torch: 这是主要的Pytorch库。它提供了构建、训练和评估神经网络的工具
- torch.distributed: torch.distributed包提供Pytorch支持和通信基元,对多进程并行,在一个或多个机器上运行的若干个计算阶段
- torch.nn: torch下包含用于搭建神经网络的modules和可用于继承的类的一个子包
- yaml: yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件
- torch.cuda.amp: 自动混合精度训练 —— 节省显存并加快推理速度
- torch.nn.parallel: 构建分布式模型,并行加速程度更高,且支持多节点多gpu的硬件拓扑结构
- torch.optim: 优化器 Optimizer。主要是在模型训练阶段对模型可学习参数进行更新,常用优化器有 SGD,RMSprop,Adam等
- tqdm: 就是我们看到的训练时进度条显示
1.3 获取当前文件的绝对路径
'''===================2.获取当前文件的绝对路径========================'''
FILE = Path(__file__).resolve() # __file__指的是当前文件(即train.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/train.py
ROOT = FILE.parents[0] # YOLOv5 root directory ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path: # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径
sys.path.append(str(ROOT)) # add ROOT to PATH 把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative ROOT设置为相对路径
这段代码会获取当前文件的绝对路径,并使用Path库将其转换为Path对象。
这一部分的主要作用有两个:
- 将当前项目添加到系统路径上,以使得项目中的模块可以调用。
- 将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件。
1.4 加载自定义模块
'''===================3..加载自定义模块============================'''
import val # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, NCOLS, check_dataset, check_file, check_git_status, check_img_size,
check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:
- val: 这个是测试集,我们下一篇再具体讲
- models.experimental: 实验性质的代码,包括MixConv2d、跨层权重Sum等
- models.yolo: yolo的特定模块,包括BaseModel,DetectionModel,ClassificationModel,parse_model等
- utils.autoanchor: 定义了自动生成锚框的方法
- utils.autobatch: 定义了自动生成批量大小的方法
- utils.callbacks: 定义了回调函数,主要为logger服务
- utils.datasets: dateset和dateloader定义代码
- utils.downloads: 谷歌云盘内容下载
- utils.general.py: 定义了一些常用的工具函数,比如检查文件是否存在、检查图像大小是否符合要求、打印命令行参数等等
- utils.loggers : 日志打印
- utils.loss: 存放各种损失函数
- utils.metrics: 模型验证指标,包括ap,混淆矩阵等
- utils.plots.py: 定义了Annotator类,可以在图像上绘制矩形框和标注信息
- utils.torch_utils.py: 定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等
通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。
1.5 分布式训练初始化
'''================4.分布式训练初始化==========================='''
# https://pytorch.org/docs/stable/elastic/run.html该网址有详细介绍
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # -本地序号。这个 Worker 是这台机器上的第几个 Worker
RANK = int(os.getenv('RANK', -1)) # -进程序号。这个 Worker 是全局第几个 Worker
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # 总共有几个 Worker
'''
查找名为LOCAL_RANK,RANK,WORLD_SIZE的环境变量,
若存在则返回环境变量的值,若不存在则返回第二个参数(-1,默认None)
rank和local_rank的区别: 两者的区别在于前者用于进程间通讯,后者用于本地设备分配。
'''
接下来是设置分布式训练时所需的环境变量。分布式训练指的是多GPU训练,将训练参数分布在多个GPU上进行训练,有利于提升训练效率。
🚀二、执行main()函数
2.1 检查分布式训练环境
def main(opt, callbacks=Callbacks()):
'''
2.1 检查分布式训练环境
'''
# Checks
if RANK in [-1, 0]: # 若进程编号为-1或0
# 输出所有训练参数 / 参数以彩色的方式表现
print_args(FILE.stem, opt)
# 检测YOLO v5的github仓库是否更新,若已更新,给出提示
check_git_status()
# 检查requirements.txt所需包是否都满足
check_requirements(exclude=['thop'])
这段代码主要是检查分布式训练的环境。
若RANK为-1或0,会执行下面三行代码,打印参数并检查github仓库和依赖库。
- 第一行代码,负责打印文件所用到的参数信息,这个参数包括命令行传入进去的参数以及默认参数
- 第二行代码,检查yolov5的github仓库是否更新,如果更新的话,会有一个提示
- 第三行代码,检查requirements中要求的安装包有没有正确安装成功,没有成功的话会给予一定的提示
2.2 判断是否断点续训
'''
2.2 判断是否断点续训
'''
# Resume
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
# isinstance()是否是已经知道的类型
# 如果resume是True,则通过get_lastest_run()函数找到runs为文件夹中最近的权重文件last.pt
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
# 判断是否为文件,若不是文件抛出异常
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
# opt.yaml是训练时的命令行参数文件
with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
# 超参数替换,将训练时的命令行参数加载进opt参数对象中
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
# opt.cfg设置为'' 对应着train函数里面的操作(加载权重时是否加载权重里的anchor)
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
# 打印从ckpt恢复断点训练信息
LOGGER.info(f'Resuming training from ckpt')
else:
# 不使用断点续训,就从文件中读取相关参数
# check_file (utils/general.py)的作用为查找/下载文件 并返回该文件的路径。
opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \\
check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
# 如果模型文件和权重文件为空,弹出警告
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
# 如果要进行超参数进化,重建保存路径
if opt.evolve:
# 设置新的项目输出目录
opt.project = str(ROOT / 'runs/evolve')
# 将resume传递给exist_ok
opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
# 根据opt.project生成目录,并赋值给opt.save_dir 如: runs/train/exp1
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
这段代码主要是关于断点训练的判断和准备。
断点训练是当训练异常终止或想调节超参数时,系统会保留训练异常终止前的超参数与训练参数,当下次训练开始时,并不会从头开始,而是从上次中断的地方继续训练。
- 使用断点续训,就从last.pt中读取相关参数
- 不使用断点续训,就从文件中读取相关参数
2.3 判断是否分布式训练
'''
2.3 判断是否分布式训练
'''
# DDP mode --> 支持多机多卡、分布式训练
# 选择程序装载的位置
device = select_device(opt.device, batch_size=opt.batch_size)
# 当进程内的GPU编号不为-1时,才会进入DDP
if LOCAL_RANK != -1:
# 用于DDP训练的GPU数量不足
assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
# WORLD_SIZE表示全局的进程数
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
# 不能使用图片采样策略
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
# 不能使用超参数进化
assert not opt.evolve, '--evolve argument is not compatible with DDP training'
# 设置装载程序设备
torch.cuda.set_device(LOCAL_RANK)
# 保存装载程序的设备
device = torch.device('cuda', LOCAL_RANK)
# torch.distributed是用于多GPU训练的模块
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
这段代码主要是检查DDP训练的配置,并设置GPU。
DDP(Distributed Data Parallel)用于单机或多机的多GPU分布式训练,但目前DDP只能在Linux下使用。这部分它会选择你是使用cpu还是gpu,假如你采用的是分布式训练的话,它就会额外执行下面的一些操作,我们这里一般不会用到分布式,所以也就没有执行什么东西。
2.4 判断是否进化训练
'''
2.4 判断是否进化训练
'''
# Train 训练模式: 如果不进行超参数进化,则直接调用train()函数,开始训练
if not opt.evolve:# 如果不使用超参数进化
# 开始训练
train(opt.hyp, opt, device, callbacks)
if WORLD_SIZE > 1 and RANK == 0:
# 如果全局进程数大于1并且RANK等于0
# 日志输出 销毁进程组
LOGGER.info('Destroying process group... ')
# 训练完毕,销毁所有进程
dist.destroy_process_group()
这段代码是不进行进化训练的情况,此时正常训练。
如果输入evolve会执行else下面这些代码,因为我们没有输入evolve并且不是分布式训练,因此会执行train函数。也就是说,当不使用超参数进化训练时,直接把命令行参数传入train函数,训练完成后销毁所有进程。
接下来我们再看看使用超参数进化训练的情况:
# Evolve hyperparameters (optional) 遗传进化算法,边进化边训练
else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
# 超参数列表(突变范围 - 最小值 - 最大值)
meta = 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'box': (1, 0.02, 0.2), # box loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0) # segment copy-paste (probability)
# 加载默认超参数
with open(opt.hyp, errors='ignore') as f:
hyp = yaml.safe_load(f) # load hyps dict
# 如果超参数文件中没有'anchors',则设为3
if 'anchors' not in hyp: # anchors commented in hyp.yaml
hyp['anchors'] = 3
# 使用进化算法时,仅在最后的epoch测试和保存
opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
if opt.bucket:
os.system(f'gsutil cp gs://opt.bucket/evolve.csv save_dir') # download evolve.csv if exists
"""
遗传算法调参:遵循适者生存、优胜劣汰的法则,即寻优过程中保留有用的,去除无用的。
遗传算法需要提前设置4个参数: 群体大小/进化代数/交叉概率/变异概率
"""
这段代码是使用超参数进化训练的前期准备
首先指定每个超参数的突变范围、最大值、最小值,再为超参数的结果保存做好准备。
# 选择超参数的遗传迭代次数 默认为迭代300次
for _ in range(opt.evolve): # generations to evolve
# 如果evolve.csv文件存在
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
# Select parent(s)
# 选择超参进化方式,只用single和weighted两种
parent = 'single' # parent selection method: 'single' or 'weighted'
# 加载evolve.txt
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
# 选取至多前五次进化的结果
n = min(5, len(x)) # number of previous results to consider
# fitness()为x前四项加权 [P, R, mAP@0.5, mAP@0.5:0.95]
# np.argsort只能从小到大排序, 添加负号实现从大到小排序, 算是排序的一个代码技巧
x = x[np.argsort(-fitness(x))][:n] # top n mutations
# 根据(mp, mr, map50, map)的加权和来作为权重计算hyp权重
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
# 根据不同进化方式获得base hyp
if parent == 'single' or len(x) == 1:
# 根据权重的几率随机挑选适应度历史前5的其中一个
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == 'weighted':
# 对hyp乘上对应的权重融合层一个hpy, 再取平均(除以权重和)
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate 突变(超参数进化)
mp, s = 0.8, 0.2 # mutation probability, sigma:突变概率
npr = np.random
# 根据时间设置随机数种子
npr.seed(int(time.time()))
# 获取突变初始值, 也就是meta三个值的第一个数据
# 三个数值分别对应着: 变异初始概率, 最低限值, 最大限值(mutation scale 0-1, lower_limit, upper_limit)
g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
ng = len(meta)
# 确保至少其中有一个超参变异了
v = np.ones(ng)
# 设置突变
while all(v == 1): # mutate until a change occurs (prevent duplicates)
v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
# 将突变添加到base hyp上
for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
hyp[k] = float(x[i + 7] * v[i]) # mutate
# Constrain to limits 限制hyp在规定范围内
for k, v in meta.items():
# 这里的hyp是超参数配置文件对象
# 而这里的k和v是在元超参数中遍历出来的
# hyp的v是一个数,而元超参数的v是一个元组
hyp[k] = max(hyp[k], v[1]) # 先限定最小值,选择二者之间的大值 ,这一步是为了防止hyp中的值过小
hyp[k] = min(hyp[k], v[2]) # 再限定最大值,选择二者之间的小值
hyp[k] = round(hyp[k], 5) # 四舍五入到小数点后五位
# 最后的值应该是 hyp中的值与 meta的最大值之间的较小者
# Train mutation 使用突变后的参超,测试其效果
results = train(hyp.copy(), opt, device, callbacks)
# Write mutation results
# 将结果写入results,并将对应的hyp写到evolve.txt,evolve.txt中每一行为一次进化的结果
# 每行前七个数字 (P, R, mAP, F1, test_losses(GIOU, obj, cls)) 之后为hyp
# 保存hyp到yaml文件
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
# Plot results 将结果可视化 / 输出保存信息
plot_evolve(evolve_csv)
LOGGER.info(f'Hyperparameter evolution finished\\n'
f"Results saved to colorstr('bold', save_dir)\\n"
f'Use best hyperparameters example: $ python train.py --hyp evolve_yaml')
这段代码是开始超参数进化训练。
超参数进化的步骤如下:
- 1.若存在evolve.csv文件,读取文件中的训练数据,选择超参进化方式,结果最优的训练数据突变超参数
- 2.限制超参进化参数hyp在规定范围内
- 3.使用突变后的超参数进行训练,测试其效果
- 4.训练结束后,将训练结果可视化,输出保存信息保存至evolution.csv,用于下一次的超参数突变。
原理:根据生物进化,优胜劣汰,适者生存的原则,每次迭代都会保存更优秀的结果,直至迭代结束。最后的结果即为最优的超参数
注意:使用超参数进化时要经过至少300次迭代,每次迭代都会经过一次完整的训练。因此超参数进化及其耗时,大家需要根据自己需求慎用。
🚀三、设置opt参数
=============================================三、设置opt参数==================================================='''
def parse_opt(known=False):
parser = argparse.ArgumentParser()
# 预训练权重文件
parser.add_argument('--weights', type=str, default=ROOT / 'pretrained/yolov5s.pt', help='initial weights path')
# 训练模型
parser.add_argument('--cfg', type=str, default=ROOT / 'models/yolov5s.yaml', help='model.yaml path')
# 训练路径,包括训练集,验证集,测试集的路径,类别总数等
parser.add_argument('--data', type=str, default=ROOT / 'data/fire_data.yaml', help='dataset.yaml path')
# hpy超参数设置文件(lr/sgd/mixup)./data/hyps/下面有5个超参数设置文件,每个文件的超参数初始值有细微区别,用户可以根据自己的需求选择其中一个
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
# epochs: 训练轮次, 默认轮次为300次
parser.add_argument('--epochs', type=int, default=300)
# batchsize: 训练批次, 默认bs=16
parser.add_argument('--batch-size', type=int, default=4, help='total batch size for all GPUs, -1 for autobatch')
# imagesize: 设置图片大小, 默认640*640
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
# rect: 是否采用矩形训练,默认为False
parser.add_argument('--rect', action='store_true', help='rectangular training')
# resume: 是否接着上次的训练结果,继续训练
# 矩形训练:将比例相近的图片放在一个batch(由于batch里面的图片shape是一样的)
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
# nosave: 不保存模型 默认False(保存) 在./runs/exp*/train/weights/保存两个模型 一个是最后一次的模型 一个是最好的模型
# best.pt/ last.pt 不建议运行代码添加 --nosave
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
# noval: 最后进行测试, 设置了之后就是训练结束都测试一下, 不设置每轮都计算mAP, 建议不设置
parser.add_argument('--noval', action='store_true', help='only validate final epoch')
# noautoanchor: 不自动调整anchor, 默认False, 自动调整anchor
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
# evolve: 参数进化, 遗传算法调参
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
# bucket: 谷歌优盘 / 一般用不到
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
# cache: 是否提前缓存图片到内存,以加快训练速度,默认False
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
# mage-weights: 使用图片采样策略,默认不使用
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
# device: 设备选择
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
# parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
# multi-scale 是否进行多尺度训练
parser.add_argument('--multi-scale', default=True, help='vary img-size +/- 50%%')
# single-cls: 数据集是否多类/默认True
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
# optimizer: 优化器选择 / 提供了三种优化器
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
# sync-bn: 是否使用跨卡同步BN,在DDP模式使用
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
# dataloader的最大worker数量 (使用多线程加载图片)
parser.add_argument('--workers', type=int, default=0, help='max dataloader workers (per RANK in DDP mode)')
# 训练结果的保存路径
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
# 训练结果的文件名称
parser.add_argument('--name', default='exp', help='save to project/name')
# 项目位置是否存在 / 默认是都不存在
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
# 四元数据加载器: 允许在较低 --img 尺寸下进行更高 --img 尺寸训练的一些好处。
parser.add_argument('--quad', action='store_true', help='quad dataloader')
# cos-lr: 余弦学习率
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
# 标签平滑 / 默认不增强, 用户可以根据自己标签的实际情况设置这个参数,建议设置小一点 0.1 / 0.05
parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
# 早停止耐心次数 / 100次不更新就停止训练
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
# --freeze冻结训练 可以设置 default = [0] 数据量大的情况下,建议不设置这个参数
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
# --save-period 多少个epoch保存一下checkpoint
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
# --local_rank 进程编号 / 多卡使用
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
# Weights & Biases arguments
# 在线可视化工具,类似于tensorboard工具
parser.add_argument('--entity', default=None, help='W&B: Entity')
# upload_dataset: 是否上传dataset到wandb tabel(将数据集作为交互式 dsviz表 在浏览器中查看、查询、筛选和分析数据集) 默认False
parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
# bbox_interval: 设置界框图像记录间隔 Set bounding-box image logging interval for W&B 默认-1 opt.epochs // 10
parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
# 使用数据的版本
parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
# 作用就是当仅获取到基本设置时,如果运行命令中传入了之后才会获取到的其他配置,不会报错;而是将多出来的部分保存起来,留到后面使用
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt
opt参数解析:
- cfg: 模型配置文件,网络结构
- data: 数据集配置文件,数据集路径,类名等
- hyp: 超参数文件
- epochs: 训练总轮次
- batch-size: 批次大小
- img-size: 输入图片分辨率大小
- rect: 是否采用矩形训练,默认False
- resume: 接着打断训练上次的结果接着训练
- nosave: 不保存模型,默认False
- notest: 不进行test,默认False
- noautoanchor: 不自动调整anchor,默认False
- evolve: 是否进行超参数进化,默认False
- bucket: 谷歌云盘bucket,一般不会用到
- cache-images: 是否提前缓存图片到内存,以加快训练速度,默认False
- weights: 加载的权重文件
- name: 数据集名字,如果设置:results.txt to results_name.txt,默认无
- device: 训练的设备,cpu;0(表示一个gpu设备cuda:0);0,1,2,3(多个gpu设备)
- multi-scale: 是否进行多尺度训练,默认False
- single-cls: 数据集是否只有一个类别,默认False
- adam: 是否使用adam优化器
- sync-bn: 是否使用跨卡同步BN,在DDP模式使用
- local_rank: gpu编号
- logdir: 存放日志的目录 workers: dataloader的最大worker数量
(关于调参,推荐大家看@迪菲赫尔曼大佬的这篇文章:手把手带你调参YOLOv5 (v5.0-v7.0)(验证))
🚀四、执行train()函数
4.1 加载参数和初始化配置信息
4.1.1 载入参数
''' =====================1.载入参数和初始化配置信息========================== '''
'''
1.1 载入参数
'''
def train(hyp, # 超参数 可以是超参数配置文件的路径或超参数字典 path/to/hyp.yaml or hyp
opt, # main中opt参数
device, # 当前设备
callbacks # 用于存储Loggers日志记录器中的函数,方便在每个训练阶段控制日志的记录情况
):
# 从opt获取参数。日志保存路径,轮次、批次、权重、进程序号(主要用于分布式训练)等
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \\
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \\
opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
这段代码是接收传来的参数。
- hyp: 超参数,不使用超参数进化的前提下也可以从opt中获取
- opt: 全部的命令行参数
- device: 指的是装载程序的设备
- callbacks: 指的是训练过程中产生的一些参数
4.1.2 创建训练权重目录和保存路径
'''
1.2 创建训练权重目录,设置模型、txt等保存的路径
'''
# Directories 获取记录训练日志的保存路径
# 设置保存权重路径 如runs/train/exp1/weights
w = save_dir / 'weights' # weights dir
# 新建文件夹 weights train evolve
(w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
# 保存训练结果的目录,如last.pt和best.pt
last, best = w / 'last.pt', w / 'best.pt'
这段代码主要是创建权重文件保存路径,权重名字和训练日志txt文件
每次训练结束后,系统会产生两个模型,一个是last.pt,一个是best.pt。顾名思义,last.pt即为训练最后一轮产生的模型,而best.pt是训练过程中,效果最好的模型。
然后创建文件夹,保存训练结果的模型文件路径 以及验证集输出结果的txt文件路径,包含迭代的次数,占用显存大小,图片尺寸,精确率,召回率,位置损失,类别损失,置信度损失和map等。
4.1.3 读取超参数配置文件
'''
1.3 读取hyp(超参数)配置文件
'''
# Hyperparameters 加载超参数
if isinstance(hyp, str): # isinstance()是否是已知类型。 判断hyp是字典还是字符串
# 若hyp是字符串,即认定为路径,则加载超参数为字典
with open(hyp, errors='ignore') as f:
# 加载yaml文件
hyp = yaml.safe_load(f) # load hyps dict 加载超参信息
# 打印超参数 彩色字体
LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'k=v' for k, v in hyp.items()))
这段代码主要是加载一些训练过程中需要使用的超参数,并打印出来
首先,检查超参数是字典还是字符串,若为字符串,则认定为.yaml文件路径,再将yaml文件加载为字典。这里导致超参数的数据类型不同的原因是,超参数进化时,传入train()函数的超参数即为字典。而从命令行参数中读取的则为文件路径。
然后将打印这些超参数。
4.1.4 设置参数的保存路径
'''
1.4 设置参数的保存路径
'''
# Save run settings 保存训练中的参数hyp和opt
with open(save_dir / 'hyp.yaml', 'w') as f:
# 保存超参数为yaml配置文件
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
# 保存命令行参数为yaml配置文件
yaml.safe_dump(vars(opt), f, sort_keys=False)
# 定义数据集字典
data_dict = None
这段代码主要是将训练的相关参数全部写入
将本次运行的超参数(hyp)和选项操作(opt)给保存成yaml格式,保存在了每次训练得到的exp文件中,这两个yaml显示了我们本次训练所选择的hyp超参数和opt参数。
还有一点,yaml.safe_load(f)是加载yaml的标准函数接口,保存超参数为yaml配置文件。 yaml.safe_dump()是将yaml文件序列化,保存命令行参数为yaml配置文件。vars(opt)
的作用是把数据类型是Namespace的数据转换为字典的形式。
4.1.5 加载日志信息
'''
1.5 加载相关日志功能:如tensorboard,logger,wandb
'''
# Loggers 设置wandb和tb两种日志, wandb和tensorboard以上是关于YOLOv5 最详细的源码逐行解读的主要内容,如果未能解决你的问题,请参考以下文章
YOLOv5源码逐行超详细注释与解读——训练部分train.py
YOLOv5源码逐行超详细注释与解读——网络结构yolo.py