实现混合精度的两种方式

Posted AI浩

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了实现混合精度的两种方式相关的知识,希望对你有一定的参考价值。

1、使用apex

APEX是英伟达开源的,完美支持PyTorch框架,用于改变数据格式来减小模型显存占用的工具。其中最有价值的是amp(Automatic Mixed Precision),将模型的大部分操作都用Float16数据类型测试,一些特别操作仍然使用Float32。并且用户仅仅通过三行代码即可完美将自己的训练代码迁移到该模型。实验证明,使用Float16作为大部分操作的数据类型,并没有降低参数,在一些实验中,反而由于可以增大Batch size,带来精度上的提升,以及训练速度上的提升。

使用方式

在混合精度训练上,Apex 的封装十分优雅。直接使用 amp.initialize 包装模型和优化器,apex 就会自动帮助我们管理模型参数和优化器的精度了,根据精度需求不同可以传入其他配置参数。

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_train(output, targets)
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
         else:
            loss.backward()

其中 opt_level 为精度的优化设置,O0(第一个字母是大写字母O):

O0:纯FP32训练,可以作为accuracy的baseline;
O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
O3:纯FP16训练,很不稳定,但是可以作为speed的baselin

2、pytorch自带函数

PyTorch 1.6的发行版包含了对PyTorch进行自动混合精度训练的本地实现。这里的主要思想是,与在所有地方都使用单精度(FP32)相比,某些操作可以在半精度(FP16)下运行得更快,而且不会损失精度。然后,AMP自动决定应该以何种格式执行何种操作。这允许更快的训练和更小的内存占用。

使用方法:

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()
for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.cuda.amp.autocast():
      loss = model(data)
   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()
   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)
   # Updates the scale for next iteration
   scaler.update()

以上是关于实现混合精度的两种方式的主要内容,如果未能解决你的问题,请参考以下文章

Question20180128 补位的两种方式

正则表达式验证6~30位数字,下划线,中划线,字母任意两种混合的密码验证策略

混合精度训练

React 组件的两种创建方式

iOS星级评价的两种实现方式

Android中Fragment与Activity之间的交互(两种实现方式)