MMEngine理解
Posted Arrow
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了MMEngine理解相关的知识,希望对你有一定的参考价值。
MMEngine理解
- 1 简介
- 2 上手示例
- 3. 基础模块
- 4. 高级模块
- 5. 恢复训练
- 6. 加速训练
- 7. 节省显存
- 8. 跨库调用模块
1 简介
- MMEngine 是一个用于深度学习模型训练的基础库,基于 PyTorch,支持在 Linux、Windows、macOS 上运行。它具有如下三个亮点:
- 通用:MMEngine 实现了一个高级的通用训练器,它能够:
- 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 imagenet(pytorch example 400 行)
- 轻松兼容流行的算法库如 TIMM、TorchVision 和 Detectron2 中的模型
- 统一:MMEngine 设计了一个接口统一的开放架构,使得
- 用户可以仅依赖一份代码实现所有任务的轻量化,例如 MMRazor 1.x 相比 MMRazor 0.x 优化了 40% 的代码量
- 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
- 灵活:MMEngine 实现了“乐高”式的训练流程,支持了
- 根据迭代数、 loss 和评测结果等动态调整的训练流程、优化策略和数据增强策略,例如早停(early stopping)机制等
- 任意形式的模型权重平均,如 Exponential Momentum Average (EMA) 和 Stochastic Weight Averaging (SWA)
- 训练过程中针对任意数据和任意节点的灵活可视化和日志控制
- 对神经网络模型中各个层的优化配置进行细粒度调整
- 混合精度训练的灵活控制
1.1 架构
- 上图展示了 MMEngine 在 OpenMMLab 2.0 中的层次。MMEngine 实现了 OpenMMLab 算法库的新一代训练架构,为 OpenMMLab 中的 30 多个算法库提供了统一的执行基座。其核心组件包含训练引擎、评测引擎和模块管理等。
1.2 模块介绍
- MMEngine 将训练过程中涉及的组件和它们的关系进行了抽象,如上图所示。不同算法库中的同类型组件具有相同的接口定义。
1.2.1 核心模块与相关组件
- 训练引擎的核心模块是执行器(Runner)。 执行器负责执行训练、测试和推理任务并管理这些过程中所需要的各个组件。在训练、测试、推理任务执行过程中的特定位置,执行器设置了钩子(Hook) 来允许用户拓展、插入和执行自定义逻辑。执行器主要调用如下组件来完成训练和推理过程中的循环:
- 数据集(Dataset):负责在训练、测试、推理任务中构建数据集,并将数据送给模型。实际使用过程中会被数据加载器(DataLoader)封装一层,数据加载器会启动多个子进程来加载数据。
- 模型(Model):在训练过程中接受数据并输出 loss;在测试、推理任务中接受数据,并进行预测。分布式训练等情况下会被模型的封装器(Model Wrapper,如MMDistributedDataParallel)封装一层。
- 优化器封装(Optimizer):优化器封装负责在训练过程中执行反向传播优化模型,并且以统一的接口支持了混合精度训练和梯度累加。
- 参数调度器(Parameter Scheduler):训练过程中,对学习率、动量等优化器超参数进行动态调整。
- 在训练间隙或者测试阶段,评测指标与评测器(Metrics & Evaluator)会负责对模型性能进行评测。其中评测器负责基于数据集对模型的预测进行评估。评测器内还有一层抽象是评测指标,负责计算具体的一个或多个评测指标(如召回率、正确率等)。
- 在训练、推理执行过程中,上述各个组件都可以调用日志管理模块和可视化器进行结构化和非结构化日志的存储与展示。日志管理(Logging Modules):负责管理执行器运行过程中产生的各种日志信息。其中消息枢纽 (MessageHub)负责实现组件与组件、执行器与执行器之间的数据共享,日志处理器(Log Processor)负责对日志信息进行处理,处理后的日志会分别发送给执行器的日志器(Logger)和可视化器(Visualizer)进行日志的管理与展示。可视化器(Visualizer):可视化器负责对模型的特征图、预测结果和训练过程中产生的结构化日志进行可视化,支持 Tensorboard 和 WanDB 等多种可视化后端。
1.2.1 公共基础模块
- MMEngine 中还实现了各种算法模型执行过程中需要用到的公共基础模块,包括:
- 配置类(Config):在 OpenMMLab 算法库中,用户可以通过编写 config 来配置训练、测试过程以及相关的组件。
- 注册器(Registry):负责管理算法库中具有相同功能的模块。MMEngine 根据对算法库模块的抽象,定义了一套根注册器,算法库中的注册器可以继承自这套根注册器,实现模块的跨算法库调用。
- 文件读写(File I/O):为各个模块的文件读写提供了统一的接口,以统一的形式支持了多种文件读写后端和多种文件格式,并具备扩展性。
- 分布式通信原语(Distributed Communication Primitives):负责在程序分布式运行过程中不同进程间的通信。这套接口屏蔽了分布式和非分布式环境的区别,同时也自动处理了数据的设备和通信后端。
- 其他工具(Utils):还有一些工具性的模块,如 ManagerMixin,它实现了一种全局变量的创建和获取方式,执行器内很多全局可见对象的基类就是 ManagerMixin。
2 上手示例
- 以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、 可配置的训练和验证流程
2.1 构建模型
- 首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 “loss”,并返回一个包含 “loss” 字段的字典;对于验证,我们需要 mode 接受字符串 “predict”,并返回同时包含预测信息和真实信息的结果。
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return 'loss': F.cross_entropy(x, labels)
elif mode == 'predict':
return x, labels
2.2 构建数据集和数据加载器
- 其次,我们需要构建训练和验证所需要的数据集 (Dataset)和数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
2.3 构建评测指标
- 为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 process 和 compute_metrics 方法。其中 process 方法接受数据集的输出和模型 mode=“predict” 时的输出,此时的数据为一个批次的数据,对这一批次的数据进行处理后,保存信息至 self.results 属性。 而 compute_metrics 接受 results 参数,这一参数的输入为 process 中保存的所有信息 (如果是分布式环境,results 中为已收集的,包括各个进程 process 保存信息的结果),利用这些信息计算并返回保存有评测指标结果的字典。
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
# 将一个批次的中间结果保存至 `self.results`
self.results.append(
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
)
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
# 返回保存有评测指标结果的字典,其中键为指标名称
return dict(accuracy=100 * total_correct / total_size)
2.4 构建执行器并执行任务
- 最后,我们利用构建好的模型,数据加载器,评测指标构建一个执行器 (Runner),同时在其中配置 优化器、工作路径、训练与验证配置等选项,即可通过调用 train() 接口启动训练:
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
# 用以训练和验证的模型,需要满足特定的接口需求
model=MMResNet50(),
# 工作路径,用以保存训练日志、权重文件信息
work_dir='./work_dir',
# 训练数据加载器,需要满足 PyTorch 数据加载器协议
train_dataloader=train_dataloader,
# 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
# 训练配置,用于指定训练周期、验证间隔等信息
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
# 验证数据加载器,需要满足 PyTorch 数据加载器协议
val_dataloader=val_dataloader,
# 验证配置,用于指定验证所需要的额外参数
val_cfg=dict(),
# 用于验证的评测器,这里使用默认评测器,并评测指标
val_evaluator=dict(type=Accuracy),
)
runner.train()
3. 基础模块
3.1 注册器(Registry)
- OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 ResNet 和 SEResNet 类,这些类有相似的功能和接口,都属于算法库中的模型组件。 为了管理这些功能相似的模块,MMEngine 实现了 注册器。 OpenMMLab 大多数算法库均使用注册器来管理它们的代码模块,包括 MMDetection, MMDetection3D,MMPose, MMClassification 和 MMEditing 等。
3.1.1 什么是注册器
- MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。
- 映射表:维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 “ResNet” 到 ResNet 类或函数的映射,使得用户可以通过 “ResNet” 找到 ResNet 类;
- 模块构建方法:定义了如何根据字符串查找到对应的类或函数以及如何实例化这个类或者调用这个函数,例如,通过字符串 “bn” 找到 nn.BatchNorm2d 并实例化 BatchNorm2d 模块;又或者通过字符串 “build_batchnorm2d” 找到 build_batchnorm2d 函数并返回该函数的调用结果。
- MMEngine 中的注册器默认使用 build_from_cfg 函数来查找并实例化字符串对应的类或者函数。
- 一个注册器:管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 MODELS 可以被视作所有模型的抽象,管理了 ResNet, SEResNet 和 RegNetX 等分类网络的类以及 build_ResNet, build_SEResNet 和 build_RegNetX 等分类网络的构建函数。
- 注册器的定义(部分代码)
class Registry:
"""A registry to map strings to classes or functions.
Registered object could be built from registry. Meanwhile, registered
functions could be called from registry.
Args:
name (str): Registry name.
build_func (callable, optional): A function to construct instance
from Registry. :func:`build_from_cfg` is used if neither ``parent``
or ``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Defaults to None.
parent (:obj:`Registry`, optional): Parent registry. The class
registered in children registry could be built from parent.
Defaults to None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Defaults to None.
Examples:
>>> # define a registry
>>> MODELS = Registry('models')
>>> # registry the `ResNet` to `MODELS`
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> # build model from `MODELS`
>>> resnet = MODELS.build(dict(type='ResNet'))
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = MODELS.build(dict(type='resnet50'))
>>> # hierarchical registry
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
>>> @DETECTORS.register_module()
>>> class FasterRCNN:
>>> pass
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
More advanced usages can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
def __init__(self,
name: str,
build_func: Optional[Callable] = None,
parent: Optional['Registry'] = None,
scope: Optional[str] = None):
from .build_functions import build_from_cfg
self._name = name
self._module_dict: Dict[str, Type] = dict()
self._children: Dict[str, 'Registry'] = dict()
if scope is not None:
assert isinstance(scope, str)
self._scope = scope
else:
self._scope = self.infer_scope()
# See https://mypy.readthedocs.io/en/stable/common_issues.html#
# variables-vs-type-aliases for the use
self.parent: Optional['Registry']
if parent is not None:
assert isinstance(parent, Registry)
parent._add_child(self)
self.parent = parent
else:
self.parent = None
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
self.build_func: Callable
if build_func is None:
if self.parent is not None:
self.build_func = self.parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
3.1.2 使用流程
- 使用注册器管理代码库中的模块,需要以下三个步骤:
- 创建注册器
- 创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)
- 将模块加入注册器中
- 假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。
3.1.2.1 创建注册器
from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
ACTIVATION = Registry('activation', scope='mmengine')
3.1.2.2 定义要注册的模块(类或函数)
import torch.nn as nn
# 使用注册器管理模块
@ACTIVATION.register_module()
class Sigmoid(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
print('call Sigmoid.forward')
return x
@ACTIVATION.register_module()
class ReLU(nn.Module):
def __init__(self, inplace=False):
super().__init__()
def forward(self, x):
print('call ReLU.forward')
return x
@ACTIVATION.register_module()
class Softmax(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
print('call Softmax.forward')
return x
- 使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 ACTIVATION 中。通过 @ACTIVATION.register_module() 装饰所实现的模块,字符串和类或函数之间的映射就可以由 ACTIVATION 构建和维护,我们也可以通过 ACTIVATION.register_module(module=ReLU) 实现同样的功能。
- 通过注册,我们就可以通过 ACTIVATION 建立字符串与类或函数之间的映射:
print(ACTIVATION.module_dict)
#
# 'Sigmoid': __main__.Sigmoid,
# 'ReLU': __main__.ReLU,
# 'Softmax': __main__.Softmax
#
- 只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 custom_imports 字段动态导入该模块进而触发注册机制,详情见导入自定义 Python 模块。
3.1.2.3 通过配置激活模块
- 模块成功注册后,我们可以通过配置文件使用这个激活模块。
import torch
input = torch.randn(2)
act_cfg = dict(type='Sigmoid')
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call Sigmoid.forward
print(output)
#如果我们想使用 ReLU,仅需修改配置。
act_cfg = dict(type='ReLU', inplace=True)
activation = ACTIVATION.build(act_cfg)
output = activation(input)
# call ReLU.forward
print(output)
3.1.3 跨项目调用
- MMEngine 的注册器支持层级注册,利用该功能可实现跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。
- 为了方便跨库调用,MMEngine 提供了 20 个根注册器:
- RUNNERS: Runner 的注册器
- RUNNER_CONSTRUCTORS: Runner 的构造器
- LOOPS: 管理训练、验证以及测试流程,如 EpochBasedTrainLoop
- HOOKS: 钩子,如 CheckpointHook, ParamSchedulerHook
- DATASETS: 数据集
- DATA_SAMPLERS: DataLoader 的 Sampler,用于采样数据
- TRANSFORMS: 各种数据预处理,如 Resize, Reshape
- MODELS: 模型的各种模块
- MODEL_WRAPPERS: 模型的包装器,如 MMDistributedDataParallel,用于对分布式数据并行
- WEIGHT_INITIALIZERS: 权重初始化的工具
- OPTIMIZERS: 注册了 PyTorch 中所有的 Optimizer 以及自定义的 Optimizer
- OPTIM_WRAPPER: 对 Optimizer 相关操作的封装,如 OptimWrapper,AmpOptimWrapper
- OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper 的构造器
- PARAM_SCHEDULERS: 各种参数调度器,如 MultiStepLR
- METRICS: 用于计算模型精度的评估指标,如 Accuracy
- EVALUATOR: 用于计算模型精度的一个或多个评估指标
- TASK_UTILS: 任务强相关的一些组件,如 AnchorGenerator, BboxCoder
- VISUALIZERS: 管理绘制模块,如 DetVisualizer 可在图片上绘制预测框
- VISBACKENDS: 存储训练日志的后端,如 LocalVisBackend, TensorboardVisBackend
- LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 LogProcessor,如有特殊需求可自定义 LogProcessor
3.2 配置(Config)
- MMEngine 实现了抽象的配置类(Config),为用户提供统一的配置访问接口。配置类能够支持不同格式的配置文件,包括 python,json,yaml,用户可以根据需求选择自己偏好的格式。配置类提供了类似字典或者 Python 对象属性的访问接口,用户可以十分自然地进行配置字段的读取和修改。为了方便算法框架管理配置文件,配置类也实现了一些特性,例如配置文件的字段继承等。
3.2.1 配置文件读取
- 配置类提供了统一的接口 Config.fromfile(),来读取和解析配置文件。
- 合法的配置文件应该定义一系列键值对,这里举几个不同格式配置文件的例子。
- Python 格式:
test_int = 1
test_list = [1, 2, 3]
test_dict = dict(key1='value1', key2=0.1)
- Json 格式:
"test_int": 1,
"test_list": [1, 2, 3],
"test_dict": "key1": "value1", "key2": 0.1
- YAML 格式:
test_int: 1
test_list: [1, 2, 3]
test_dict:
key1: "value1"
key2: 0.1
- 对于以上三种格式的文件,假设文件名分别为 config.py,config.json,config.yml,调用 Config.fromfile(‘config.xxx’) 接口加载这三个文件都会得到相同的结果,构造了包含 3 个字段的配置对象。我们以 config.py 为例,我们先将示例配置文件下载到本地:
from mmengine.config import Config
cfg = Config.fromfile('learn_read_config.py')
print(cfg)
- 输出结果为:
Config (path: learn_read_config.py): 'test_int': 1, 'test_list': [1, 2, 3], 'test_dict': 'key1': 'value1', 'key2': 0.1
3.2.2 配置文件的使用
- 通过读取配置文件来初始化配置对象后,就可
以上是关于MMEngine理解的主要内容,如果未能解决你的问题,请参考以下文章