PyTorch Tutorial for Beginners (14) Transfer Learning: Fine-tuning ResNet to Classify Images of Men and Women
Posted by 刘润森
@Author:Runsen
Last time we fine-tuned AlexNet; this time we fine-tune ResNet to classify images of men and women.
ResNet is short for Residual Networks, a classic neural network architecture that serves as a backbone for many computer vision tasks.
- The ResNet paper is available here:
https://arxiv.org/abs/1512.03385
The model won the 2015 ImageNet challenge. ResNet's fundamental breakthrough is that it made it possible to successfully train extremely deep neural networks with more than 150 layers.
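The key idea behind these very deep networks is the residual (skip) connection: each block learns a residual F(x) that is added back to its own input, so gradients can flow around the block. Below is a minimal sketch of a basic residual block, simplified from the torchvision implementation (no downsampling or stride handling):

import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Simplified residual block: out = relu(F(x) + x)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x                          # keep the input for the skip connection
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity                  # residual addition
        return self.relu(out)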
The overall resnet18 architecture is a convolutional stem followed by four stages of residual BasicBlocks, global average pooling, and a final fully connected layer.
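One quick way to inspect the exact layer-by-layer structure is to print the torchvision model; a minimal sketch:

from torchvision import models

# instantiate resnet18 just to inspect its structure (the weights don't matter here)
resnet18 = models.resnet18()
print(resnet18)  # conv1, bn1, relu, maxpool, layer1-layer4 (BasicBlock stacks), avgpool, fc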
ResNet-18 here is an image classification model pretrained on the ImageNet dataset.
This time we use ResNet-18 to classify a gender dataset containing 58,658 images in total (train: 47,009 / val: 11,649).
- Dataset: Kaggle Gender Classification Dataset
Loading the dataset
Set the image directory paths and initialize the PyTorch data loaders, using the same boilerplate as in the previous tutorials.
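ImageFolder infers the class labels from sub-directory names, so the data directory is assumed to be laid out roughly as follows (the exact folder names depend on how the Kaggle dataset was extracted):

# ./gender_classification_dataset/
#     Training/
#         female/   *.jpg
#         male/     *.jpg
#     Validation/
#         female/   *.jpg
#         male/     *.jpg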
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device object
transforms_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalization
])
transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
data_dir = './gender_classification_dataset'
train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Training'), transforms_train)
val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Validation'), transforms_val)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=16, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_datasets, batch_size=16, shuffle=True, num_workers=4)
print('Train dataset size:', len(train_datasets))
print('Validation dataset size:', len(val_datasets))
class_names = train_datasets.classes
print('Class names:', class_names)
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 60
plt.rcParams.update({'font.size': 20})
def imshow(input, title):
    # torch.Tensor => numpy
    input = input.numpy().transpose((1, 2, 0))
    # undo image normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    input = std * input + mean
    input = np.clip(input, 0, 1)
    # display images
    plt.imshow(input)
    plt.title(title)
    plt.show()
# load a batch of train image
iterator = iter(train_dataloader)
# visualize a batch of train image
inputs, classes = next(iterator)
out = torchvision.utils.make_grid(inputs[:4])
imshow(out, title=[class_names[x] for x in classes[:4]])
Defining the model
With the transfer learning approach, we only need to replace the final fully connected layer.
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2) # binary classification (num_of_class == 2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
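The code above fine-tunes all layers of the pretrained network with a small learning rate. If you prefer pure feature extraction, i.e. training only the new classification head while keeping the backbone frozen, a minimal sketch of that variant:

model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False                 # freeze all pretrained weights
model.fc = nn.Linear(model.fc.in_features, 2)   # the new head is trainable by default
model = model.to(device)
# pass only the head's parameters to the optimizer
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)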
Training
Since ResNet-18 is a deep network and training is relatively slow, we only train for num_epochs = 3 here.
num_epochs = 3
start_time = time.time()

for epoch in range(num_epochs):
    """ Training """
    model.train()
    running_loss = 0.
    running_corrects = 0
    # iterate over batches of training images
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        # backpropagate the loss and update the network weights
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    epoch_loss = running_loss / len(train_datasets)
    epoch_acc = running_corrects / len(train_datasets) * 100.
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

    """ Validation """
    model.eval()
    with torch.no_grad():
        running_loss = 0.
        running_corrects = 0
        for inputs, labels in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        epoch_loss = running_loss / len(val_datasets)
        epoch_acc = running_corrects / len(val_datasets) * 100.
        print('[Validation #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
Saving the trained model
save_path = 'face_gender_classification_transfer_learning_with_ResNet18.pth'
torch.save(model.state_dict(), save_path)
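state_dict() stores the weights only. If you also want to be able to resume training later, a common pattern is to save a fuller checkpoint; a sketch (the file name is just an example):

checkpoint_path = 'gender_resnet18_checkpoint.pth'   # example file name
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, checkpoint_path)

# resume later:
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])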
Loading the trained model
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)
model.load_state_dict(torch.load(save_path))
model.to(device)
model.eval()
start_time = time.time()
with torch.no_grad():
    running_loss = 0.
    running_corrects = 0
    for i, (inputs, labels) in enumerate(val_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        # show a few example images from the first batch
        if i == 0:
            print('[Prediction Result Examples]')
            images = torchvision.utils.make_grid(inputs[:4])
            imshow(images.cpu(), title=[class_names[x] for x in labels[:4]])
            images = torchvision.utils.make_grid(inputs[4:8])
            imshow(images.cpu(), title=[class_names[x] for x in labels[4:8]])
    epoch_loss = running_loss / len(val_datasets)
    epoch_acc = running_corrects / len(val_datasets) * 100.
    print('[Validation] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch_loss, epoch_acc, time.time() - start_time))
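To run the trained model on a single photo instead of a whole DataLoader, a minimal inference sketch (the file name my_photo.jpg is hypothetical):

from PIL import Image

img = Image.open('my_photo.jpg').convert('RGB')              # hypothetical input image
input_tensor = transforms_val(img).unsqueeze(0).to(device)   # same preprocessing as validation

model.eval()
with torch.no_grad():
    probs = torch.softmax(model(input_tensor), dim=1)[0]
    pred = torch.argmax(probs).item()
print('Predicted:', class_names[pred], 'with probability {:.2f}'.format(probs[pred].item()))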
In the final evaluation, accuracy reached about 97%, but the model is relatively heavy and slow at inference time, which is often impractical in real projects.
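If inference speed matters more than the last bit of accuracy, the same fine-tuning recipe can be applied to a lighter backbone; a sketch swapping in MobileNetV2 (not covered in the original tutorial):

model = models.mobilenet_v2(pretrained=True)
# MobileNetV2's classifier is Sequential(Dropout, Linear); replace the final Linear layer
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
model = model.to(device)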