第七节:CNN练习1——使用四种网络结构(VGG、ResNet、MobileNet、InceptionNet)进行CIFAR-10训练
Posted 快乐江湖
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第七节:CNN练习1使用四种网络结构(VGGResNetMobileNetInceptionNet)进行cifar10训练相关的知识,希望对你有一定的参考价值。
文章目录
一:cifar10数据集介绍
cifar10数据集:CIFAR-10数据集是“8000万微小图片”(80 Million Tiny Images)数据集的一个带标签子集
数据集由6万张32×32的彩色图片组成,共有10个类别,每个类别6000张图片;其中5万张为训练图片,1万张为测试图片。
可使用torchvision.datasets.CIFAR10进行下载(首次下载时需将download设为True):
# CIFAR-10 training split rooted at parametes.data_path; every image is
# converted to a tensor. download=False is passed, so no download is requested.
train_dataset = datasets.CIFAR10(
    root=parametes.data_path,
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)
# CIFAR-10 test split with the same root and transform.
test_dataset = datasets.CIFAR10(
    root=parametes.data_path,
    train=False,
    transform=transforms.ToTensor(),
    download=False,
)
如下,下载后会生成5个训练块文件和1个测试块文件,每一个块文件10000张图片
这种文件并非图像文件,为了后续更好的训练,且能查看到训练过程中的图像变化,所以我们需要把它们转换为图像文件,转换代码如下
import pickle
import numpy as np
import glob
import os
import cv2
# Loader for a single pickled batch file
def unpickle(file):
    """Deserialize one pickled batch file and return its content.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is loaded with ``encoding='bytes'``,
        so 8-bit strings pickled by Python 2 come back as ``bytes``
        (callers index the result with keys such as ``b'data'``).
    """
    # NOTE: pickle can execute arbitrary code while loading; only call this
    # on batch files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # The local is deliberately not named `dict`, which would shadow the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Class names; the position in the list is the numeric label stored in the
# batch files (the script indexes this list with the label value).
label_name = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Collect the pickled training batch files. To export the test split instead,
# glob the test_batch file (see the commented line) and change save_path.
train_list = glob.glob('./cifar-10-batches-py/data_batch_*')
# test_list = glob.glob('./cifar-10-batches-py/test_batch')
# Root output directory; one sub-directory per class is created beneath it
save_path = './train/'
# Walk over every batch file
for l in train_list:
    print(l)
    l_dict = unpickle(l)
    # Each batch is a dict; the keys used below are bytes:
    #   b'labels'    - class index of every image
    #   b'data'      - one flat pixel array per image
    #   b'filenames' - original file name of every image
    # (the original notes also list a batch_label key naming the batch)
    print(l_dict.keys())
    for im_idx, im_data in enumerate(l_dict[b'data']):
        # Look up the class index and the file name of this image
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        # Map the class index to its English name
        im_label_name = label_name[im_label]
        # Flat array -> (channel, row, col) -> (row, col, channel)
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # CIFAR-10 stores the planes in R, G, B order while cv2.imwrite
        # treats the last axis as B, G, R; convert first so the saved PNG
        # does not have red and blue swapped.
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)
        # One folder per class. makedirs also creates save_path itself when it
        # is missing, and exist_ok avoids the separate existence check.
        class_dir = os.path.join(save_path, im_label_name)
        os.makedirs(class_dir, exist_ok=True)
        cv2.imwrite(os.path.join(class_dir, im_name.decode('utf-8')), im_data)
转换后文件结构如下
二:代码
(1)数据加载脚本编写
import torchvision.datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob
# Class names; the position in the list is the integer class index
label_name = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# Map every class name to its integer index. The original assignment had no
# right-hand side (the braces were lost), which is a syntax error; the mapping
# is rebuilt here as a dict comprehension.
label_dict = {name: idx for idx, name in enumerate(label_name)}
def default_loader(path):
    """Open the image file at ``path`` with PIL and return it converted to RGB mode."""
    image = Image.open(path)
    rgb_image = image.convert("RGB")
    return rgb_image
# Training-time preprocessing: a random horizontal flip, then conversion
# of the image to a tensor.
train_transforms = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)
class MyDataset(Dataset):
    """Dataset over image files whose parent directory name is the class label.

    Args:
        im_list: List of image paths, e.g.
            './data/test/airplane/aeroplane_s_000002.png'. The name of the
            directory that directly contains the file must be a key of
            ``label_dict``.
        transform: Optional callable applied to the loaded image
            (augmentation / tensor conversion). Skipped when ``None``.
        loader: Callable that takes a path and returns the image;
            defaults to ``default_loader``.

    Raises:
        KeyError: If a path's parent directory name is not in ``label_dict``.
    """

    def __init__(self, im_list, transform=None, loader=default_loader):
        super().__init__()
        # Each entry is [image_path, class_index]. The label is the parent
        # directory name, taken with os.path so it works with the native path
        # separator. Splitting on a literal doubled backslash, as before,
        # raised IndexError for '/'-separated paths.
        self.imgs = [
            [im_item, label_dict[os.path.basename(os.path.dirname(im_item))]]
            for im_item in im_list
        ]
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the sample at ``index``."""
        im_path, im_label = self.imgs[index]
        # Load the image through the configured loader
        im_data = self.loader(im_path)
        # Apply the transform only when one was supplied
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)
# File lists and datasets are built at module level (presumably so that other
# scripts can import train_dataset / test_dataset directly - confirm against
# the training script).
im_train_list = glob.glob(r'./data/train/*/*.png')
im_test_list = glob.glob(r'./data/test/*/*.png')
train_dataset = MyDataset(im_train_list, transform=train_transforms)
test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())

if __name__ == '__main__':
    # Smoke test when run as a script. The datasets built above are reused;
    # the original repeated the same glob + construction here, scanning the
    # directories twice and leaving two copies to keep in sync.
    print(len(train_dataset))
    print(len(test_dataset))
    train_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=False, num_workers=0)
"""
train_transforms = transforms.Compose([
transforms.RandomResizedCrop((28, 28)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(90),
transforms.RandomGrayscale(0.1),
transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
transforms.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor())
print(train_dataset.classes[: 5])
print("-"*30)
print(train_dataset.class_to_idx)
print("-"*30)
print(train_dataset.imgs[: 5])
"""
(2)模型搭建
①:VGG
# Layer layouts selectable by model name. An int is the output channel count
# of a 3x3 convolution; 'M' marks a 2x2 max-pooling layer.
# (The braces of this dict literal were missing, which is a syntax error.)
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    # Narrower alternative layout for vgg16, kept for reference:
    # 'vgg16': [16, 16, 'M', 32, 32, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
# Build the convolutional part of the network from a layout list
def create_conv(cfg, in_channels=None):
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: Sequence whose items are either the string ``'M'`` (adds a
            2x2, stride-2 max-pooling layer) or an int (adds a 3x3,
            padding-1 convolution with that many output channels, followed
            by an in-place ReLU).
        in_channels: Channel count of the network input. When ``None``
            (the default) the value is read from
            ``parametes.init_in_chaneels``, as the original code always did.

    Returns:
        An ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    if in_channels is None:
        in_channels = parametes.init_in_chaneels
    layers = []
    for c in cfg:
        if c == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(
                nn.Conv2d(in_channels=in_channels, out_channels=c, kernel_size=3, padding=1)
            )
            layers.append(nn.ReLU(True))
            # The next convolution consumes what this one produces. This must
            # only happen in the conv branch, never for 'M'.
            in_channels = c
    return nn.Sequential(*layers)
# VGG-style network: configurable conv stack + fixed three-layer classifier
class VGG16(nn.Module):
    """VGG classifier made of a supplied conv stack and a 3-layer MLP head.

    Args:
        conv: Module producing a feature map with 512 channels (for example
            the result of ``create_conv``).
        num_classes: Number of output logits.
        init_weights: When true, re-initialise conv and linear layers with
            ``_initialize_weights``.
    """

    def __init__(self, conv, num_classes, init_weights=False):
        super().__init__()
        self.conv = conv
        # The head expects a 512x7x7 map, which the five-pool VGG layouts
        # give only for 224x224 inputs. Adaptive pooling forces 7x7 for any
        # input size, so 32x32 CIFAR-10 images no longer hit a shape error.
        # It is an identity when the map is already 7x7, and it has no
        # parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.conv(x)
        x = self.avgpool(x)
        # Keep the batch axis, flatten channel and spatial axes
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

    def _initialize_weights(self):
        """Kaiming-normal init for convolutions, N(0, 0.01) for linear layers; biases set to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
# Instantiate the network: pick the 'vgg16' layout, build its conv stack,
# request weight initialisation, then move the model to the configured device.
cfg = model.cfgs['vgg16']
net = model.VGG16(
    model.create_conv(cfg),
    parametes.num_classes,
    True,
).to(parametes.device)
如下是VGG13,这种写法比较臃肿但清晰
import torch
import torch.nn as nn
import torch.nn.functional as F
class VGG13(nn.Module):
def __init__(self):
super(VGG13, self).__init__()
# N * 3 * 32 * 32
self.conv1_1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.conv1_2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)
# N * 64 * 16 * 16
self.conv2_1 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.conv2_2 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)
# N * 128 * 8 * 8
self.conv3_1 = nn.Sequential(
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.conv3_2 = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2)
# N * 256 * 4 * 4
self.conv4_1 = nn.Sequential(
nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.conv4_2 = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)
# N * 512 * 2 * 2
self.conv5_1 = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.conv5_2 = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.max_pooling5 = nn.MaxPool2d(kernel_size=2, stride=2)
# N * 512 * 1 * 1
# 全连接层
self.fc = nn.Sequential(
nn.Linear(512 * 1 * 1, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 10)
)
def forward(self, x):
out = self.conv1_1(x)
out = self.conv1_2(out)
out = self.max_pooling1(out)
out = self.conv2_1(out)
out = self.conv2_2(out)
out = self.max_pooling2(out)
out = self.conv3_1(out)
out = self.conv3_2(out)
out = self.max_pooling3(out)
out = self.conv4_1(out)
out = self.conv4_2(out)
out = self.max_pooling4(out)
out = self.conv5_1(out)
out = self.conv5_2(out)
out = self.max_pooling5(out)
out = torch.flatten(ou以上是关于第七节:CNN练习1使用四种网络结构(VGGResNetMobileNetInceptionNet)进行cifar10训练的主要内容,如果未能解决你的问题,请参考以下文章