Pytorch分布式训练与断点续训

Posted 洪流之源

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch分布式训练与断点续训相关的知识,希望对你有一定的参考价值。

1. Pytorch分布式训练

Pytorch支持多机多卡分布式训练,参与分布式训练的机器用Node表述(Node不限定是物理机器,还是容器,例如docker,一个Node节点就是一台机器),Node又分为Master Node、Slave Node,Master Node只有一个,Slave Node可以有多个,假定现在有两台机器参与分布式训练,每台机器有4张显卡,分别在两台机器上执行如下命令(以yolov5训练为例):

Master Node执行如下命令:

python -m torch.distributed.launch \\
       --nnodes 2 \\
       --nproc_per_node 4 \\
       --use_env \\
       --node_rank 0 \\
       --master_addr "192.168.1.2" \\
       --master_port 1234 \\
       train.py \\
       --batch 64 \\
       --data coco.yaml \\
       --cfg yolov5s.yaml \\
       --weights 'yolov5s.pt'

 Slave Node执行如下命令:

python -m torch.distributed.launch \\
       --nnodes 2 \\
       --nproc_per_node 4 \\        
       --use_env \\
       --node_rank 1 \\
       --master_addr "192.168.1.2" \\
       --master_port 1234 train.py \\
       --batch 64 \\
       --data coco.yaml \\
       --cfg yolov5s.yaml \\
       --weights 'yolov5s.pt

上述的命令中:

--nnodes:表示一共多少台机器参与分布式训练,也就是有几个Node,只有2台机器所以设置为2
--nproc_per_node:表示每台机器有多少张显卡,每台机器有4张显卡所以设置为4
--node_rank:表示当前机器的序号,一般设置为0的作为Master Node
--master_add:表示Master Node的ip地址
--master_add:表示Master Node的端口号


一般情况下,当两台机器都运行完命令,训练就开始了,否则master处于等待状态,直到slave节点也就绪,分布式训练才会开始。

如下所示,训练开始后Node节点的分配情况:

Node 0
    Process0 [Global Rank=0, Local Rank=0] -> GPU 0
    Process1 [Global Rank=1, Local Rank=1] -> GPU 1
    Process2 [Global Rank=2, Local Rank=2] -> GPU 2
    Process3 [Global Rank=3, Local Rank=3] -> GPU 3
Node 1
    Process4 [Global Rank=4, Local Rank=0] -> GPU 0
    Process5 [Global Rank=5, Local Rank=1] -> GPU 1
    Process6 [Global Rank=6, Local Rank=2] -> GPU 2
    Process7 [Global Rank=7, Local Rank=3] -> GPU 3

每一张显卡被分配一个进程,从Process0 ~ Process7,Global Rank表示在整个分布式训练任务中的分布式进程编号,从Global Rank = 0 ~ Global Rank = 7

Local Rank表示在某个Node内的编号,在Node 0中Local Rank = 0 ~ Local Rank = 3。另外,如果只写rank前面没有global、local等字段一般指代Global Rank。

world_size表示全局进程数量,也就是分布式进程的数量,在上述的配置中world_size = 8,如果一共有3个node(nnodes=3),每个node包含8个GPU,设置nproc_per_node=4,world_size就是3 * 4 = 12,为什么不是3 * 8 = 24呢?因为每个node虽然有8个GPU,但是命令设置只使用其中4个(nproc_per_node=4),有而不使用是不算数的。
 

训练算法的时候如果要使用分布式训练,需要对训练流程添加分布式的支持,主要有如下步骤:

1) 初始化分布式进程环境;

2)对数据集构建分布式采样器;

3)对网络模型用DistributedDataParallel进行包装;

4)日志与模型保存在主进程中进行;

5)对loss、评估指标等数据进行all_reduce同步;

6)如果网络中存在BN层,可开启BN同步

7)如果多机进行分布式训练,需要保证Node直接网络互通。

更多分布式训练的知识可参考:(471条消息) Pytorch中多GPU并行计算教程_太阳花的小绿豆的博客-CSDN博客_pytorch 多gpu

2. 断点续训

断点续训比较简单,在训练的过程中需要在checkpoints中保存能够回复模型训练的数据,主要包括:模型权重参数、优化器、学习率调度器、当前训练的轮次(比如epoch数),另外,也可以保存评估指标、训练参数等数据。在恢复训练的时候将上述数据从checkpoint中取出,从当前状态继续训练。

3. 分布式训练与断点续训代码示例

dist_classification_train.py

import argparse
import os
import numpy as np
import random
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torchvision import models
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from warmup_scheduler import GradualWarmupScheduler
from datetime import datetime
import sys

def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"]) # RANK表示global rank, 表示当前进程在所有进程中的序号
        args.world_size = int(os.environ['WORLD_SIZE']) # WORLD_SIZE表示分布式训练的所有的进程数量
        args.gpu = int(os.environ['LOCAL_RANK']) # 在当前分布式节点中进程的编号,因为一个显卡分配一个进程,因此这个编号也是当前进程运行的GPU编号
    elif 'SLURM_PROCID' in os.environ: # 在SLURM集群上使用
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu) # 指定当前进程所使用的的GPU
    args.dist_backend = 'nccl'  # 通信后端,nvidia GPU推荐使用NCCL
    print('| distributed init (rank ): '.format(args.rank, args.dist_url), flush=True)
    
    # 初始化分布式进程组
    dist.init_process_group(backend=args.dist_backend, 
                            init_method=args.dist_url, # 指的是如何初始化,以完成刚开始的进程同步,'file://' or 'tcp://' or 'env://',默认采用环境变量即'env://'
                            world_size=args.world_size, 
                            rank=args.rank)
    dist.barrier() # 同步所有的进程, 也就是所有节点的所有GPU运行至这个函数的时候, 才会执行后面的代码

def is_dist_avail_and_initialized():
    """检查是否支持分布式环境"""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    return get_rank() == 0

def reduce_value(value, average=True):
    world_size = get_world_size()
    if world_size < 2:  # 单GPU的情况
        return value

    with torch.no_grad():
        dist.all_reduce(value) # 所有参与训练GPU设备的value值的总和
        if average:
            value /= world_size # 对value值求均值,得到每个GPU上的的平均value

        return value

def train_one_epoch(model, optimizer, data_loader, loss_function, device, epoch, args):
    model.train()
    optimizer.zero_grad()
    mean_loss = torch.zeros(1).to(device)

    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data
        outputs = model(images.to(device))

        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        loss = reduce_value(loss, average=True) # 得到多个GPU计算loss的均值
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update mean losses

        # 如果loss是非法值,则退出
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

        # 如果是主进程打印平均loss
        if is_main_process():
            data_loader.desc = " Training:Epoch[:0>3/:0>3] Iteration[:0>3/:0>3] Loss: :.4f".format(
                                datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, step + 1, len(data_loader), 
                                round(loss.item(), 3))
    
    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    return mean_loss.item()

@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)

    # 如果是主进程打印验证进度
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)
        
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()
        data_loader.desc = " Validating:Iteration[:0>3/:0>3]".format(
                           datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 
                           step + 1, len(data_loader))

    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    sum_num = reduce_value(sum_num, average=False) # 求和得到在所有GPU上正确样本个数

    return sum_num.item()

def main(args):
    if not torch.cuda.is_available():
        raise EnvironmentError("not find GPU device for training.")

    # 初始化分布式进程环境
    init_distributed_mode(args=args)

    rank = args.rank

    device = torch.device(args.device)

    # 相同的随机种子seed将模型在初始化过程中所用到的“随机数”全部固定下来,
    # 以保证每次重新训练模型需要初始化模型参数的时候能够得到相同的初始化参数,
    # 从而达到稳定复现训练结果的目的
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    batch_size = args.batch_size
    weights_path = args.pretrained_weights

    if rank == 0:  # 如果是master node的一个进程
        print(args)
        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
        summary_writer = SummaryWriter()
        if not os.path.exists(args.save_checkpoints_dir):
            os.makedirs(args.save_checkpoints_dir)

    transform_train = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(args.model_input_size), 
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
    transform_test = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(args.model_input_size),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
    train_dataset = datasets.cifar.CIFAR100(root='cifar100', train=True, transform=transform_train, download=True)
    test_dataset = datasets.cifar.CIFAR100(root='cifar100', train=False, transform=transform_test, download=True)

    # 给每个rank对应的进程分配训练的样本索引
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)

    # 将样本索引每batch_size个元素组成一个list
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # 数据加载进程数量

    if rank == 0: # 如果是master node的一个进程
        print('Using  dataloader workers every process'.format(num_workers))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,
                                               num_workers=num_workers)

    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=batch_size,
                                             sampler=val_sampler,
                                             pin_memory=True,
                                             num_workers=num_workers)

    # tensorboard 显示数据
    if rank == 0:
        data_batch, label_batch = next(iter(train_loader))
        images_grid = torchvision.utils.make_grid(tensor=data_batch, nrow=8, normalize=True, scale_each=True)
        summary_writer.add_image(tag='image', img_tensor=images_grid, global_step=0)

    # 模型
    model = models.resnet18(num_classes=args.num_classes)
    
    # tensorboard可视化模型计算图
    if rank == 0:
        dummy_input = torch.randn(1, 3, args.model_input_size, args.model_input_size)
        summary_writer.add_graph(model=model, input_to_model=dummy_input, verbose=False)

    # 如果存在预训练权重则载入
    if os.path.exists(weights_path):
        weights_dict = torch.load(weights_path, map_location=device)
        load_weights_dict = k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")
        # 分布式训练是模型并行的,需要保证加载到各个GPU上的模型具有相同的权重,
        # 因此如果不存在预训练权重,需要将第一个进程中的权重保存,
        # 然后其它进程载入,保持初始化权重一致
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)

        dist.barrier() # 分布式进程同步

        # 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后的全连接层外,其他权重全部冻结
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # 只有训练带有BN结构的网络时使用SyncBatchNorm才有意义
        # 但使用SyncBatchNorm后训练会更耗时
        if args.syncBN:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # 包装为DDP模型
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # 损失函数
    criterion = nn.CrossEntropyLoss()

    # optimizer
    train_param = [parm for parm in model.parameters() if parm.requires_grad] # 去掉冻结层参数,只获取网络可训练参数
    optimizer = optim.SGD(train_param, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # warmup与余弦退火学习率
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=args.warmup_lr_times, total_epoch=args.warmup_epoch, after_scheduler=cosine_scheduler)
    
    best_acc = 0.0

    # 断点续训
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=checkpoint['epoch'])
        lr_scheduler.after_scheduler = cosine_scheduler
        args.start_epoch = checkpoint['epoch'] + 1

    for epoch in range(args.start_epoch, args.epochs):
        # 在分布式模式下,需要在每个epoch开始时调用set_epoch()方法,
        # 然后再创建DataLoader迭代器,以使shuffle操作能够在多个epoch中正常工作。 
        # 否则,dataloader迭代器产生的数据将始终使用相同的顺序
        train_sampler.set_epoch(epoch)

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    loss_function=criterion,
                                    device=device,
                                    epoch=epoch,
                                    args=args)

        lr_scheduler.step()
        lr = optimizer.param_groups[0]['lr']

        sum_num = evaluate(model=model, data_loader=val_loader, device=device)
        val_accuracy = sum_num / val_sampler.total_size

        if rank == 0:
            print(" Epoch[:0>3/:0>3] LR::.6 Validation Acc::.2%".format(
                  datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, lr, val_accuracy))
            summary_writer.add_scalar("Validation Accuracy", val_accuracy, epoch)
            summary_writer.add_scalar("Train Loss", mean_loss, epoch)
            summary_writer.add_scalar("Learing Rate", lr, epoch)
            for name, param in model.named_parameters():
                summary_writer.add_histogram(name + '_grad', param.grad, epoch)
                summary_writer.add_histogram(name + '_data', param, epoch)
            
            # 保存checkpoint
            torch.save(
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'accuracy': round(val_accuracy, 3),
                    'args': args,
                , "/model__.pth".format(args.save_checkpoints_dir, epoch, round(val_accuracy, 3)))

            if val_accuracy > best_acc:
                best_acc = val_accuracy
                torch.save(
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'accuracy': round(val_accuracy, 3),
                    'args': args,
                , "/best.pth".format(args.save_checkpoints_dir))
    
    # 删除临时缓存文件
    if rank == 0:
        if os.path.exists(checkpoint_path) is True:
            os.remove(checkpoint_path)

    dist.destroy_process_group() # 是否所有的进程组资源

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-classes', type=int, default=100)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=0.005)
    parser.add_argument('--warmup-lr-times', type=int, default=1)
    parser.add_argument('--warmup-epoch', type=int, default=10)
    parser.add_argument('--save-checkpoints-dir', type=str, default='checkpoints')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', type=int, default=1, help='start epoch')
    parser.add_argument('--syncBN', type=bool, default=True)
    parser.add_argument('--pretrained-weights', type=str, default='resNet18.pth', help='initial weights path')
    parser.add_argument('--model-input-size', type=int, default=224)
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--world-size', default=4, type=int, help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    opt = parser.parse_args()

    main(opt)

如果在分布式训练中只有一台Master Node,4张显卡,可通过如下命令启动分布式训练:

python -m torch.distributed.launch  --nproc_per_node=4 --use_env dist_classification_train.py

以上是关于Pytorch分布式训练与断点续训的主要内容,如果未能解决你的问题,请参考以下文章

pytorch实现断点续训

PyTorch保存模型断点以及加载断点继续训练

基于pytorch实现简单的分类模型训练

tensorflow的断点续训

断点续训

第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像