Swin Transformer实战: timm使用MixupCutout和评分一网打尽,图像分类任务
Posted AI浩
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Swin Transformer实战: timm使用MixupCutout和评分一网打尽,图像分类任务相关的知识,希望对你有一定的参考价值。
文章目录
摘要
本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演示如何使用timm版本的Swin Transformer图像分类模型实现分类任务已经对验证集得分的统计,本文实现了多个GPU并行训练。
通过本文你和学到:
1、如何从timm调用模型、loss和Mixup?
2、如何制作ImageNet数据集?
3、如何使用Cutout数据增强?
4、如何使用Mixup数据增强。
5、如何实现多个GPU训练和验证。
6、如何使用余弦退火调整学习率?
7、如何使用classification_report实现对模型的评价。
8、预测的两种写法。
Swin Transformer简介
目标检测刷到58.7 AP!
实例分割刷到51.1 Mask AP!
语义分割在ADE20K上刷到53.5 mIoU!
今年,微软亚洲研究院的Swin Transformer又开启了吊打CNN的模式,在速度和精度上都有很大的提高。这篇文章带你实现Swin Transformer图像分类。
资料汇总
论文: https://arxiv.org/abs/2103.14030
代码: https://github.com/microsoft/Swin-Transformer
论文翻译:https://wanghao.blog.csdn.net/article/details/120724040
一些大佬的B站视频:
1、霹雳吧啦Wz:https://www.bilibili.com/video/BV1yg411K7Yc?from=search&seid=18074716460851088132&spm_id_from=333.337.0.0
2、ClimbingVision社区:震惊!这个关于Swin Transformer的论文分享讲得太透彻了!_哔哩哔哩_bilibili
关于Swin Transformer的资料有很多,在这里就不一一列举了,我觉得理解这个模型的最好方式:源码+论文。
数据增强Cutout和Mixup
为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:
pip install torchtoolbox
Cutout实现,在transforms中。
from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
需要导入包:from timm.data.mixup import Mixup,
定义Mixup,和SoftTargetCrossEntropy
mixup_fn = Mixup(
mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
prob=0.1, switch_prob=0.5, mode='batch',
label_smoothing=0.1, num_classes=12)
criterion_train = SoftTargetCrossEntropy()
项目结构
Swin_demo
├─data
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─mean_std.py
├─makedata.py
├─train.py
├─test1.py
└─test.py
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
计算mean和std
为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
def get_mean_and_std(train_data):
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _ in train_loader:
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
print(get_mean_and_std(train_dataset))
数据集结构:
运行结果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
把这个结果记录下来,后面要用!
生成数据集
我们整理还的图像分类的数据集结构是这样的
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
pytorch和keras默认加载方式是ImageNet数据集格式,格式是
├─data
│ ├─val
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
新增格式转化脚本makedata.py,插入代码:
import glob
import os
import shutil
image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
print('true')
#os.rmdir(file_dir)
shutil.rmtree(file_dir)#删除再建立
os.makedirs(file_dir)
else:
os.makedirs(file_dir)
from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
file_class=file.replace("\\\\","/").split('/')[-2]
file_name=file.replace("\\\\","/").split('/')[-1]
file_class=os.path.join(train_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
file_class=file.replace("\\\\","/").split('/')[-2]
file_name=file.replace("\\\\","/").split('/')[-1]
file_class=os.path.join(val_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
训练
完成上面的步骤后,就开始train脚本的编写,新建train.py.
导入项目使用的库
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.models import swin_small_patch4_window7_224
from torchtoolbox.transform import Cutout
设置全局参数
设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使用CPU。建议使用GPU,CPU太慢了。
# 设置全局参数
model_lr = 1e-4
BATCH_SIZE = 4
EPOCHS = 1000
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
图像预处理与增强
数据处理比较简单,加入了Cutout、做了Resize和归一化,定义Mixup函数。
# 数据预处理7
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])
])
mixup_fn = Mixup(
mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
prob=0.1, switch_prob=0.5, mode='batch',
label_smoothing=0.1, num_classes=12)
读取数据
使用pytorch默认读取数据的方式,然后将dataset_train.class_to_idx打印出来,预测的时候要用到。
# 读取数据
dataset_train = datasets.ImageFolder('data/train', transform=transform)
dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
print(dataset_train.class_to_idx)
# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
class_to_idx的结果:
‘Black-grass’: 0, ‘Charlock’: 1, ‘Cleavers’: 2, ‘Common Chickweed’: 3, ‘Common wheat’: 4, ‘Fat Hen’: 5, ‘Loose Silky-bent’: 6, ‘Maize’: 7, ‘Scentless Mayweed’: 8, ‘Shepherds Purse’: 9, ‘Small-flowered Cranesbill’: 10, ‘Sugar beet’: 11
设置模型
- 设置loss函数,train的loss为:SoftTargetCrossEntropy,val的loss:nn.CrossEntropyLoss()。
- 设置模型为swin_small_patch4_window7_224,预训练设置为true,num_classes设置为12。
- 检测可用显卡的数量,如果大于1,则要用torch.nn.DataParallel加载模型,开启多卡训练。
- 优化器设置为adam。
- 学习率调整策略选择为余弦退火。
# 实例化模型并且移动到GPU
criterion_train = SoftTargetCrossEntropy()
criterion_val = torch.nn.CrossEntropyLoss()
model_ft = swin_small_patch4_window7_224(pretrained=True)
print(model_ft)
num_ftrs = model_ft.head.in_features
model_ft.head = nn.Linear(num_ftrs, 12)
model_ft.to(DEVICE)
print(model_ft)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model_ft = torch.nn.DataParallel(model_ft)
print(model_ft)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=model_lr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
定义训练和验证函数
定义训练函数和验证函数,在一个epoch完成后,使用classification_report计算详细的得分情况。
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch):
model.train()
sum_loss = 0
total_num = len(train_loader.dataset)
print(total_num, len(train_loader))
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
samples, targets = mixup_fn(data, target)
optimizer.zero_grad()
output = model(data)
loss = criterion_train(output, targets)
loss.backward()
optimizer.step()
lr = optimizer.state_dict()['param_groups'][0]['lr']
print_loss = loss.data.item()
sum_loss += print_loss
if (batch_idx + 1) % 10 == 0:
print('Train Epoch: [/ (:.0f%)]\\tLoss: :.6f\\tLR::.9f'.format(
epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
ave_loss = sum_loss / len(train_loader)
print('epoch:,loss:'.format(epoch, ave_loss))
ACC = 0
# 验证过程
def val(model, device, test_loader):
global ACC
model.eval()
test_loss = 0
correct = 0
total_num = len(test_loader.dataset)
print(total_num, len(test_loader))
val_list = []
pred_list = []
with torch.no_grad():
for data, target in test_loader:
for t in target:
val_list.append(t.data.item())
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion_val(output, target)
_, pred = torch.max(output.data, 1)
for p in pred:
pred_list.append(p.data.item())
correct += torch.sum(pred == target)
print_loss = loss.data.item()
test_loss += print_loss
correct = correct.data.item()
acc = correct / total_num
avgloss = test_loss / len(test_loader)
print('\\nVal set: Average loss: :.4f, Accuracy: / (:.0f%)\\n'.format(
avgloss, correct, len(test_loader.dataset), 100 * acc))
if acc > ACC:
torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
ACC = acc
return val_list, pred_list
# 训练
for epoch in range(1, EPOCHS + 1):
train(model_ft, DEVICE, train_loader, optimizer, epoch)
cosine_schedule.step()
val_list, pred_list = val(model_ft, DEVICE, test_loader)
print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))
运行结果:
测试
我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:
测试集存放的目录如下图:
第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!
第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。
第三步 加载model,并将模型放在DEVICE里,
第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
'Common wheat','Fat Hen', 'Loose Silky-bent',
'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])
])
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)
path='data/test/'
testList=os.listdir(path)
for file in testList:
img=Image.open(path+file)
img=transform_test(img)
img.unsqueeze_(0)
img = Variable(img).to(DEVICE)
out=model(img)
# Predict
_, pred = torch.max(out.data, 1)
print('Image Name:,predict:'.format(file,classes[pred.data.item()]))
运行结果:
第二种 使用自定义的Dataset读取图片
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
'Common wheat','Fat Hen', 'Loose Silky-bent',
'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([
transforms.Resize((224,以上是关于Swin Transformer实战: timm使用MixupCutout和评分一网打尽,图像分类任务的主要内容,如果未能解决你的问题,请参考以下文章
Swin Transformer实战:使用 Swin Transformer实现图像分类。
Swin Transformer实战:使用 Swin Transformer实现图像分类。
Swin Transformer v2实战:使用Swin Transformer v2实现图像分类
Swin-Transformer 图像分割实战:使用Swin-Transformer-Semantic-Segmentation训练ADE20K数据集(语义分割)