Training a Simple Classification Model with PyTorch

Posted by 洪流之源

The main features are as follows:

  1. Distributed training (DistributedDataParallel);

  2. Cosine annealing learning rate schedule with warmup;

  3. Resuming training from a checkpoint;

  4. Saving training logs to file;

  5. Label smoothing (the exact loss is spelled out right after this list);

  6. Layer freezing;

  7. Training classification models from torchvision.models;

  8. TensorBoard visualization of the learning rate, training loss, and validation accuracy;

  9. TensorBoard visualization of the model graph;

  10. TensorBoard visualization of training images;

  11. TensorBoard visualization of model weights;
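
For reference, with smoothing factor ε (eps, 0.1 by default) and C classes, the LabelSmoothingCrossEntropy class in the code below computes, per sample, (1 − ε)·(−log p_y) + (ε / C)·Σ_c(−log p_c), where p is the softmax output and y the ground-truth class; with reduction='mean' this is averaged over the batch.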

The full code is as follows:

import os
import argparse
import numpy as np
import random
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
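# note: GradualWarmupScheduler is not part of PyTorch; it comes from the third-party
# 'warmup_scheduler' package (pytorch-gradual-warmup-lr), which is assumed to be installed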
from warmup_scheduler import GradualWarmupScheduler
from datetime import datetime
import sys
import logging
import logging.handlers
import time


def init_log(logfile):
    logger = logging.getLogger('mylog')

    logger.setLevel(logging.INFO)
    logger.propagate = False

    hdlr = logging.handlers.TimedRotatingFileHandler(logfile, 'D', 1, 10)

    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)

    logger.addHandler(hdlr)

    return logger

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean'):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction=='sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction=='mean':
                loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)

def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"]) # RANK is the global rank, i.e. the index of this process among all processes
        args.world_size = int(os.environ['WORLD_SIZE']) # WORLD_SIZE is the total number of processes in the distributed job
        args.gpu = int(os.environ['LOCAL_RANK']) # index of this process on the current node; with one GPU per process, this is also the GPU index it runs on
    elif 'SLURM_PROCID' in os.environ: # used on SLURM clusters
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return
 
    args.distributed = True
 
    torch.cuda.set_device(args.gpu) # bind this process to its GPU
    args.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    
    # initialize the distributed process group
    dist.init_process_group(backend=args.dist_backend, 
                            init_method=args.dist_url, # how the initial process synchronization is done: 'file://', 'tcp://' or 'env://' (environment variables, the default here)
                            world_size=args.world_size, 
                            rank=args.rank)
    dist.barrier() # synchronize all processes: code after this line runs only once every GPU on every node has reached it

def is_dist_avail_and_initialized():
    """检查是否支持分布式环境"""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    return get_rank() == 0

def reduce_value(value, average=True):
    world_size = get_world_size()
    if world_size < 2:  # single-GPU case
        return value
 
    with torch.no_grad():
        dist.all_reduce(value) # sum the value over all GPUs taking part in training
        if average:
            value /= world_size # average it, giving the mean value per GPU
 
        return value

def train_one_epoch(model, optimizer, data_loader, loss_function, device, epoch, args, logger):
    model.train()
    optimizer.zero_grad()
    mean_loss = torch.zeros(1).to(device)
 
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)
 
    for step, data in enumerate(data_loader):
        images, labels = data
        outputs = model(images.to(device))
 
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        loss = reduce_value(loss, average=True) # average the loss computed across the GPUs
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update mean losses
 
        # abort if the loss is not finite
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            if logger is not None:  # the logger only exists on the main process
                logger.error('WARNING: non-finite loss, ending training %s', loss)
            sys.exit(1)
 
        optimizer.step()
        optimizer.zero_grad()
 
        # print the running mean loss on the main process
        if is_main_process():
            data_loader.desc = "{} Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f}".format(
                                datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, step + 1, len(data_loader), 
                                round(loss.item(), 3))
            logger.info("{} Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f}".format(
                                datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, step + 1, len(data_loader), 
                                round(loss.item(), 3)))
    
    # wait for all processes to finish
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)
 
    return mean_loss.item()

@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()
 
    # counts the number of correctly predicted samples
    sum_num = torch.zeros(1).to(device)
 
    # show validation progress on the main process
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)
        
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()
        data_loader.desc = "{} Validating:Iteration[{:0>3}/{:0>3}]".format(
                           datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 
                           step + 1, len(data_loader))
 
    # wait for all processes to finish
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)
 
    sum_num = reduce_value(sum_num, average=False) # sum the number of correct samples over all GPUs
 
    return sum_num.item()

def main(args):
    if not torch.cuda.is_available():
        raise EnvironmentError("no GPU device found for training.")
    
    # initialize the distributed process environment
    init_distributed_mode(args=args)

    logger = None

    rank = args.rank
 
    device = torch.device(args.device)
 
    # Using the same random seed fixes all of the "random numbers" used when the model
    # is initialized, so every fresh training run starts from the same initial parameters
    # and results can be reproduced reliably.
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    batch_size = args.batch_size
    
    # model: look up the requested architecture in torchvision.models
    model = None
    try:
        model_fn = getattr(torchvision.models, args.classification_model)
        if os.path.exists(args.pretrained_weights):
            model = model_fn(pretrained=False)
        else:
            model = model_fn(pretrained=True)
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, args.num_classes)  # replace the final FC layer to match num_classes
    except Exception as e:
        print(e)
        exit(-1)
    
    if rank == 0:  # only on the process of the master node (rank 0)
        if not os.path.exists(args.save_checkpoints_dir):
            os.makedirs(args.save_checkpoints_dir)

        date_time = time.strftime('%Y_%m_%d_%H_%M_%S')
        log_dir = os.path.join(args.save_checkpoints_dir, date_time)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        log_file = os.path.join(log_dir, 'train.log')
        logger = init_log(log_file)

        args.save_checkpoints_dir = log_dir

        print_args = 'train config param: \n'
        print('\ntrain config param:')
        for param_key in args.__dict__:
            param = str(param_key) + ": " + str(args.__dict__[param_key])
            print(param)
            print_args = print_args + param + '\n'

        print('\n')
        print('Start Tensorboard with "tensorboard --logdir={}", view at http://localhost:6006/'.format(log_dir))
        logger.info(print_args)
        logger.info('Start Tensorboard with "tensorboard --logdir={}", view at http://localhost:6006/'.format(log_dir))
        summary_writer = SummaryWriter(log_dir=log_dir)
        
 
    transform_train = transforms.Compose([transforms.Resize([args.model_input_size, args.model_input_size]),
                                          transforms.RandomVerticalFlip(),
                                          transforms.RandomHorizontalFlip(),
                                          # transforms.CenterCrop(args.model_input_size), 
                                          transforms.RandomRotation(15),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=args.mean, std=args.std)])
 
    transform_val = transforms.Compose([transforms.Resize([args.model_input_size, args.model_input_size]),
                                        # transforms.CenterCrop(args.model_input_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=args.mean, std=args.std)])
    

    train_dataset = datasets.ImageFolder(root=args.train_dataset, transform=transform_train)
    val_dataset = datasets.ImageFolder(root=args.val_dataset, transform=transform_val)
 
    # assign each rank's process its own subset of training sample indices
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
 
    # group the sample indices into lists of batch_size elements
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of data-loading worker processes
 
    if rank == 0: # only on the process of the master node
        logger.info('Using {} dataloader workers per process'.format(num_workers))
 
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,
                                               num_workers=num_workers)
 
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             sampler=val_sampler,
                                             pin_memory=True,
                                             num_workers=num_workers)
 
    # TensorBoard: show a batch of training images
    if rank == 0:
        data_batch, _ = next(iter(train_loader))
        images_grid = torchvision.utils.make_grid(tensor=data_batch, nrow=8, normalize=True, scale_each=True)
        summary_writer.add_image(tag='image', img_tensor=images_grid, global_step=0)
 
    # TensorBoard: visualize the model graph
    if rank == 0:
        dummy_input = torch.randn(1, 3, args.model_input_size, args.model_input_size)
        summary_writer.add_graph(model=model, input_to_model=dummy_input, verbose=False)
    
    checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")

    # load pretrained weights if they exist
    if os.path.exists(args.pretrained_weights):
        weights_dict = torch.load(args.pretrained_weights, map_location=device)
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        
        # DDP training keeps a full replica of the model on every GPU, so all replicas
        # must start from identical weights. If no pretrained weights are given, the
        # rank-0 process saves its initial weights and every other process loads them,
        # keeping the initialization consistent.
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
 
        dist.barrier() # synchronize the distributed processes
 
        # note: map_location must be specified here, otherwise the first GPU ends up holding extra memory
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # optionally freeze layers
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # SyncBatchNorm only makes sense for networks that contain BN layers,
        # but it makes training slower
        if args.syncBN:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
 
    # wrap the model with DistributedDataParallel
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
 
    # loss function
    if args.label_smoothing:
        criterion = LabelSmoothingCrossEntropy()
    else:
        criterion = nn.CrossEntropyLoss()
    

    # optimizer
    train_param = [parm for parm in model.parameters() if parm.requires_grad] # keep only trainable parameters, skipping frozen layers
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(train_param, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(train_param, lr=args.lr)
 
    # warmup followed by a cosine annealing learning rate schedule
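    # (assumption about the third-party package: with multiplier=1, GradualWarmupScheduler
    # ramps the LR linearly from 0 up to the base LR over the first warmup_epoch epochs,
    # then hands control to the cosine schedule passed as after_scheduler)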
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=args.warmup_lr_times, total_epoch=args.warmup_epoch, after_scheduler=cosine_scheduler)
    
    best_acc = 0.0
 
    # resume training from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=checkpoint['epoch'])
        lr_scheduler.after_scheduler = cosine_scheduler
        args.start_epoch = checkpoint['epoch'] + 1
 
    for epoch in range(args.start_epoch, args.epochs):
        # In distributed mode, set_epoch() must be called at the start of every epoch,
        # before creating the DataLoader iterator, so that shuffling works across epochs.
        # Otherwise the dataloader iterator yields the data in the same order every epoch.
        train_sampler.set_epoch(epoch)
 
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    loss_function=criterion,
                                    device=device,
                                    epoch=epoch,
                                    args=args,
                                    logger=logger)
 
        lr_scheduler.step()
        lr = optimizer.param_groups[0]['lr']
 
        sum_num = evaluate(model=model, data_loader=val_loader, device=device)
        val_accuracy = sum_num / val_sampler.total_size
 
        if rank == 0:
            logger.info("{} Epoch[{:0>3}/{:0>3}] LR:{:.6f} Validation Acc:{:.2%}".format(
                  datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, lr, val_accuracy))
            print("{} Epoch[{:0>3}/{:0>3}] LR:{:.6f} Validation Acc:{:.2%}".format(
                  datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch, args.epochs, lr, val_accuracy))
            summary_writer.add_scalar("Validation Accuracy", val_accuracy, epoch)
            summary_writer.add_scalar("Train Loss", mean_loss, epoch)
            summary_writer.add_scalar("Learning Rate", lr, epoch)
            for name, param in model.named_parameters():
                if param.grad is not None:  # frozen layers have no gradient
                    summary_writer.add_histogram(name + '_grad', param.grad, epoch)
                summary_writer.add_histogram(name + '_data', param, epoch)
            
            if epoch % args.save_interval == 0:
                # save a checkpoint
                torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'accuracy': round(val_accuracy, 3),
                        'args': args,
                    }, "{}/model_{}_{}.pth".format(args.save_checkpoints_dir, epoch, round(val_accuracy, 3)))

            # keep the best checkpoint so far
            if val_accuracy > best_acc:
                best_acc = val_accuracy
                torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'accuracy': round(val_accuracy, 3),
                    'args': args,
                }, "{}/best.pth".format(args.save_checkpoints_dir))
    
    # remove the temporary initial-weights file
    if rank == 0:
        if os.path.exists(checkpoint_path) is True:
            os.remove(checkpoint_path)
 
    dist.destroy_process_group() # release all process group resources

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--classification-model', type=str, default='resnet18', 
                        help='any classification model supported by torchvision.models, '
                             'e.g. resnet18 resnet34 resnet50 resnet101 ...')
    parser.add_argument('--train-dataset', type=str, default='datasets/train/')
    parser.add_argument('--val-dataset', type=str, default='datasets/val/')
    parser.add_argument('--mean', type=float, nargs=3, default=[0.485, 0.456, 0.406])
    parser.add_argument('--std', type=float, nargs=3, default=[0.229, 0.224, 0.225])
    parser.add_argument('--num-classes', type=int, default=3)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=0.005)
    parser.add_argument('--warmup-lr-times', type=int, default=1)
    parser.add_argument('--warmup-epoch', type=int, default=10)
    parser.add_argument('--save-checkpoints-dir', type=str, default='checkpoints')
    parser.add_argument('--save-interval', type=int, default=1)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', type=int, default=1, help='start epoch')
    parser.add_argument('--syncBN', type=bool, default=True)
    parser.add_argument('--pretrained-weights', type=str, default='', 
                        help='initial weights path; if not specified, weights are downloaded automatically from https://download.pytorch.org/models')
    parser.add_argument('--model-input-size', type=int, default=128)
    parser.add_argument('--label-smoothing', type=bool, default=True)
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--world-size', default=4, type=int, help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    opt = parser.parse_args()

    main(opt)
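
The --train-dataset and --val-dataset directories are plain torchvision ImageFolder datasets: each class gets its own sub-folder, and the folder name becomes the class label. A minimal layout matching the defaults above (num-classes of 3; the class names here are just placeholders) would be:

datasets/
├── train/
│   ├── class_a/
│   ├── class_b/
│   └── class_c/
└── val/
    ├── class_a/
    ├── class_b/
    └── class_c/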

Training launch script:

#!/bin/bash

export CUDA_DEVICE_ORDER='PCI_BUS_ID'
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

python -m torch.distributed.launch  --nproc_per_node=8 --use_env dist_classification_train.py
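
This launches eight processes on a single node, one per visible GPU; torch.distributed.launch sets the RANK, WORLD_SIZE and LOCAL_RANK environment variables that init_distributed_mode() reads (--use_env is what puts the local rank into the environment instead of passing a --local_rank argument). On newer PyTorch releases torch.distributed.launch is deprecated in favour of torchrun, which always uses environment variables, so an equivalent launch should look something like:

torchrun --nproc_per_node=8 dist_classification_train.py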
