从头开始训练神经网络(Unet)
Posted 小白抗小枪
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从头开始训练神经网络(Unet)相关的知识,希望对你有一定的参考价值。
前言:
如果将看论文看作是写代码中的学理论,那么写pytorch和训练5大项可能就相当于写代码。当然训练神经网络中还有很多的trick,来帮助我们更好的训练模型。但是今天我们只是说一下基础的几个步骤。并且我会说的很细,(手动狗头)长文警告~ 如果觉得写的可以的,麻烦给个赞哦谢谢~
训练所需环境:
对应自己GPU的pytorch版本。
数据集下载:https://github.com/dongwu92/AutoPortraitMatting
五大项:
1.训练集和验证集的制作以及导入训练集和验证集。(dataset以及dataloader的书写)
关键的一步! (甚至可以说写神经网络模型训练就是写dataset)
不同的数据集对应的不同的dataset哈。但是总体的思想是一样的。
对于dataset他的主要目的是分类数据集的作用。dataloader是pytorch官方自己写好的直接用就好了。
当拿到一个数据集最重要的不是写代码,我们需要的是观察数据集。
--PortraitDataset
--|train
而其中包括两类.png 一个是数字结尾的,一个是对应的mask,以matte结尾的。
--|test
而其中包括两类.png 一个是数字结尾的,一个是对应的mask,以matte结尾的。
这时候就可以写了~
dataset的三大项
这里主要实现3个东西
1.__init__() 初始化你的dataset和主要包括路径,transform等等
2.__len__() 返回你的数据长度
3.__getitem() 根据索引返回你的图片
(具体每一句的代码含义我会在代码块里详细介绍)
class PortraitDataset(Dataset):
def __init__(self, data_dir, transform=None, in_size = 224):
"""
:param data_dir: data的路径
:param transform: 数据预处理工作
:param in_size: 默认大小
:param label_path_list 是一个链表存储的path
"""
super(PortraitDataset, self).__init__()
self.data_dir = data_dir
self.transform = transform
self.label_path_list = list()
self.in_size = in_size
#获取mask的path
self._get_img_path()
#返回数据长度
def __len__(self):
return len(self.label_path_list)
#根据索引求出单张图片
#该getitem的好处将mask导入进去 就好了
#在输入的时候使用enumerate(train_loader) 将mask表示的后10位删掉加上png就是img的path
def __getitem__(self, item):
#img 是原图也就是RGB形式的 灰度图是分割图 是L形式的
#path_label是一个列表
path_label = self.label_path_list[item]
#社舍去list中的后10位 这里是一个很漂亮的代码后面我们在说
path_img = path_label[:-10] + ".png"
#将图像读出来 并且以rgb的形式 但是其是一个4通道的 hwca
img_pil = Image.open(path_img).convert('RGB')
img_pil = img_pil.resize((self.in_size, self.in_size), Image.BILINEAR)
img_hwc = np.array(img_pil)
#转换维度
img_chw = img_hwc.transpose((2, 0, 1))
#转为灰度图
label_pil = Image.open(path_label).convert('L')
label_pil = label_pil.resize((self.in_size, self.in_size), Image.NEAREST)
label_hw = np.array(label_pil)
#给前面增加一个通道
label_chw = label_hw[np.newaxis, :, :]
label_hw[label_hw != 0] = 1
if self.transform is not None:
img_chw_tensor = torch.from_numpy(self.transform(img_chw.numpy())).float()
label_chw_tensor = torch.from_numpy(self.transform(label_chw.numpy())).float()
else:
img_chw_tensor = torch.from_numpy(img_chw).float()
label_chw_tensor = torch.from_numpy(label_chw).float()
return img_chw_tensor, label_chw_tensor
def _get_img_path(self):
#获取对应img的mask
#将所有的图片导入到file_list中
file_list = os.listdir(self.data_dir)
#filter(function,iterable) function -- 判断函数。 iterable -- 可迭代对象
file_list = list(filter(lambda x: x.endswith("_matte.png"), file_list))
#在path_list列表中添加每一个的路径
path_list = [os.path.join(self.data_dir, name) for name in file_list]
random.shuffle(path_list)
if len(path_list) == 0:
raise Exception("\\ndata_dir: is a empty dir! Please checkout your path to images!".format(self.data_dir))
self.label_path_list = path_list
这样一个dataset就写完了。因为这个数据集中他的mask和image没有分开保存。但是他的标号却是一一对应的,这样可以少些一个split image的操作。
#超参数的设置
LR = 0.01
BATCH_SIZE = 8
max_epoch = 1
start_epoch = 0
lr_step = 150
val_interval = 3
checkpoint_interval = 20
vis_num = 10
mask_thres = 0.5
train_dir = os.path.join("..", "PortraitDataset", "train")
valid_dir = os.path.join("..", "PortraitDataset", "test")
print(train_dir)
#step 1 训练集预测集的制作
train_set = PortraitDataset(train_dir)
valid_set = PortraitDataset(valid_dir)
#step 2 导入训练集和测试集
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)
2.写网络模型
网上有很多的对于不同网络的实现
Unet 本来想自己写,确实自己写了,但是很复杂。所以我们拿一个现成的Unet模型。
class UNet(nn.Module):
def __init__(self, in_channels=3, out_channels=1, init_features=32):
super(UNet, self).__init__()
features = init_features
self.encoder1 = UNet._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder2 = UNet._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")
self.upconv4 = nn.ConvTranspose2d(
features * 16, features * 8, kernel_size=2, stride=2
)
self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose2d(
features * 8, features * 4, kernel_size=2, stride=2
)
self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose2d(
features * 4, features * 2, kernel_size=2, stride=2
)
self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose2d(
features * 2, features, kernel_size=2, stride=2
)
self.decoder1 = UNet._block(features * 2, features, name="dec1")
self.conv = nn.Conv2d(
in_channels=features, out_channels=out_channels, kernel_size=1
)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return torch.sigmoid(self.conv(dec1))
@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv2d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=False,
),
),
(name + "norm1", nn.BatchNorm2d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv2d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=False,
),
),
(name + "norm2", nn.BatchNorm2d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
这个部分没什么讲的哈
在主函数里写入
net = UNet(in_channels=3, out_channels=1, init_features=32)
net.to(device)
3.定义损失函数
这里使用交叉熵损失函数
loss_fn = nn.MSELoss()
4.梯度下降(SGD)
这里也是用最基础的SGD随机梯度下降。常见的还有BGD,ADAM
#step 5梯度下降
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)
5.训练(两层for循环)
一层是为了epoch准备的。一层是为了一个epoch里的所有数据准备的。
1.首先两个for循环 里面的for循环获得我们输入图像,索引,标签图像.
2.前向传播
3.先清空梯度(因为pytorch 不会自动清空梯度 它会将梯度累加,所以每次反向传播前都需要清空梯度)
4.计算损失 加反向传播
5.更新参数(撒花完结~)
for epoch in range(start_epoch, max_epoch):
train_loss_total = 0.
train_dice_total = 0.
net.train()
#训练数组和标签 第一个值为序号,第二个值是输入数据,第三个值是数据标签
for iter, (inputs, labels) in enumerate(train_loader):
if torch.cuda.is_available():
inputs, labels = inputs.to(device), labels.to(device)
#forward
outputs = net(inputs)
#backward
optimizer.zero_grad()
loss = loss_fn(outputs, labels)
loss.backward()
#更新参数
optimizer.step()
#打印关于
train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
train_dice_curve.append(train_dice)
train_curve.append(loss.item())
#item()取出张量具体位置的元素元素值
train_loss_total += loss.item()
#print所有
print("Training:Epoch[:0>3/:0>3] Iteration[:0>3/:0>3] running_loss: :.4f, mean_loss: :.4f "
"running_dice: :.4f lr:".format(epoch, max_epoch, iter + 1, len(train_loader), loss.item(),
train_loss_total / (iter + 1), train_dice, scheduler.get_lr()))
#学习率 更新 每一个epoch都在更新参数
scheduler.step()
#每20轮 保存一下参数
if(epoch + 1) % checkpoint_interval == 0:
checkpoint = "model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch
path_checkpoint = "./checkpoint__epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
# validate the model
if (epoch + 1) % val_interval == 0:
net.eval()
valid_loss_total = 0.
valid_dice_total = 0.
with torch.no_grad():
for j, (inputs, labels) in enumerate(valid_loader):
if torch.cuda.is_available():
inputs, labels = inputs.to(device), labels.to(device)
outputs = net(inputs)
loss = loss_fn(outputs, labels)
valid_loss_total += loss.item()
valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
valid_dice_total += valid_dice
valid_loss_mean = valid_loss_total / len(valid_loader)
valid_dice_mean = valid_dice_total / len(valid_loader)
valid_curve.append(valid_loss_mean)
valid_dice_curve.append(valid_dice_mean)
print("Valid:\\t Epoch[:0>3/:0>3] mean_loss: :.4f dice_mean: :.4f".format(
epoch, max_epoch, valid_loss_mean, valid_dice_mean))
这里为了可视化我们的神经网络或者在控制台打印我们自己想要看见的数值。
当然我们要在输出图像。看看loss之类的。这个后面我会补充
以上是关于从头开始训练神经网络(Unet)的主要内容,如果未能解决你的问题,请参考以下文章