SWA实战:使用SWA进行微调,提高模型的泛化

Posted AI浩

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SWA实战:使用SWA进行微调,提高模型的泛化相关的知识,希望对你有一定的参考价值。

摘要

论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa

论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第 i i i个epoch的checkpoint为 w i w_i wi,一般情况下我们会选择训练过程中最后的一个epoch的模型 w n w_n wn或者在验证集上效果最好的一个模型 w i ∗ w^*_i wi作为最终模型。但SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值。

pytorch使用举例:

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
    for batch_idx, (data, target) in enumerate(train_loader):   
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # 在反向传播前要手动将梯度清零
        optimizer.zero_grad()
        output = model(data)
        #计算losss
        loss = train_criterion(output, targets)
        # 反向传播求解梯度
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']   
    swa_model.update_parameters(model)
    swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。

详细实现过程

环境

pyotrch:1.10

准备

在开始今天的代码前,我们要准备好训练好的模型。然后才能开始今天的代码。

实现过程

定义模型,并将训练好的模型载入,代码如下:

    model_ft = efficientnet_b1(pretrained=True)
    print(model_ft)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, classes)
    model_ft.to(DEVICE)
    model_ft = torch.load(model_path)
    print(model_ft)
    fine_epoch = 80
    fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,
              use_amp)

定义模型为efficientnet_b1,这里要和训练的模型保持一致。

如果保存的整个模型,则使用torch.load(model_path)载入模型,如果只保存了权重信息,则要使用model_ft=load_state_dict(torch.load(model_path)),载入模型。

然后,设置fine的epoch为80。

接下来,我们一起去看fine_tune函数中的内容。

 # 采用SGD优化器
    optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
    if use_amp:
        model, optimizer = amp.initialize(model_ft, optimizer, opt_level="O1")  # 这里是“欧一”,不是“零一”

定义优化器为SGD。

如果使用混合精度,则对amp初始化。

 # 随机权重平均SWA,实现更好的泛化
 swa_model = AveragedModel(model).to(device)
 # SWA调整学习率
 swa_scheduler = SWALR(optimizer, swa_lr=1e-6)

初始化SWA。

使用SWALR调整学习率。

接下来循环epoch,这里都是比较通用的逻辑。

 for epoch in range(1, epoch + 1):
        model.train()
        train_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            if len(data) % 2 != 0:
                print(len(data))
                data = data[0:len(data) - 1]
                target = target[0:len(target) - 1]
                print(len(data))
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            samples, targets = mixup_fn(data, target)
            output = model(samples)
            loss = train_criterion(output, targets)
            optimizer.zero_grad()
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print_loss = loss.data.item()
            train_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch:  [/ (:.0f%)]\\tLoss: :.6f\\tLR::.9f'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
        swa_model.update_parameters(model)
        swa_scheduler.step()

主要步骤有:

1、计算loss。

2、是否使用amp混合精度,如果使用混合精度则使用scaled_loss反向传播求梯度,否则直接loss反向传播求梯度。

3、 swa_model.update_parameters(model)更新swa_model的参数。

4、 swa_scheduler.step()更新学习率。

等待所有的epoch执行完成后。

torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), "last.pt")

更新BN层参数。

然后保存模型的权重。注意:这里只能保存模型的权重,不能保存整个模型。

完成之后就可以测试了,执行代码:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
from torchvision.models.mobilenetv3 import mobilenet_v3_large
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR
from timm.models.efficientnet import efficientnet_b1
import numpy as np

def show_outputs(output):

    output_sorted = sorted(output, reverse=True)
    top5_str = '-----TOP 5-----\\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = ': \\n'.format(index[j], value)
            else:
                topi = '-1: 0.0\\n'
            top5_str += topi
    print(top5_str)

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = efficientnet_b1(pretrained=True)

num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 8)
swa_model = AveragedModel(model)
swa_model.load_state_dict(torch.load("last.pt"))
swa_model.to(DEVICE)
swa_model.eval()

path = 'test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = swa_model(img)
    out = out.data.cpu().numpy()[0]
    print(file)
    show_outputs(out)

这里测试代码和以前的写法没有啥区别,唯一不同的地方:

重新定义模型,然后载入权重。
运行结果:

完整代码:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85223146

以上是关于SWA实战:使用SWA进行微调,提高模型的泛化的主要内容,如果未能解决你的问题,请参考以下文章

DHCP中继器

MSTP&VRRP协议

MSTP&VRRP协议

MSTP&VRRP协议

SWA(Stochastic Weight Averaging)实验

如何防止 Apache PDFBox 中的通用签名伪造 (USF)、增量保存攻击 (ISA)、签名包装 (SWA)