Pytorch深度学习50篇·······第三篇:非监督学习
Posted 咕里个咚
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch深度学习50篇·······第三篇:非监督学习相关的知识,希望对你有一定的参考价值。
兄弟萌,我咕里个咚今天又杀回来了,有几天时间可以不用驻场了,喜大普奔,终于可以在有网的地方码代码了,最近驻场也是又热又心累啊,抓紧这几天,再更新一点的新东西。
今天主要讲一下非监督学习,你可能要问了,什么是非监督学习,我的理解就是不会给样本标签的,它本质上是一个统计手段,在没有标签的数据里可以发现潜在的一些结构的一种训练方式。这个可以用来干什么,举个例子,在工业场景瑕疵检测的运用中,由于良品的数量远远高于不良品的数量,如果这个时候你要采用监督学习,那么收集样本的时间就多得吓人了,可能你样本还没有收集完全,产品都已经做完下线了,所以,你就挠头吧。于是,非监督学习就迎来了一片蓝海,但是即使是蓝海也要你的船能开才行,这里面也不乏调整,比如非监督学习的效果就不太好去评估。当然,我想了点办法,在工业领域也有所应用了,但是过杀有点高,大概5%左右,这5%的过杀,再通过后期的监督算法,其实可以解决很多问题了,这样工业上的大部分问题,都可以有所缓解了。我真棒。哈哈哈
1.非监督学习网络架构
先提供一下,我的非监督学习的网络架构,还是基于pytorch来写的。我给这个网络一个名称叫做
piercing eye。话不多说,上代码。
from torch import nn
import torch
class CBP(nn.Module):
"""
conv + batchnormal + prelu
"""
def __init__(self,inc,ouc):
super().__init__()
self.block1=nn.Sequential(
nn.Conv2d(inc,ouc,3,1,1),
nn.BatchNorm2d(ouc),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class Up_Block(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.block1=nn.Sequential(
nn.ConvTranspose2d(in_channel, out_channel, 3, 2, 1, 1),
nn.BatchNorm2d(out_channel),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class Down_Block(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.block1=nn.Sequential(
nn.Conv2d(in_channel, out_channel, 5, 2, padding=2),
nn.BatchNorm2d(out_channel),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class PiercingEye(nn.Module):
def __init__(self):
super().__init__()
self.block1=nn.Sequential(
CBP(3, 4),
Down_Block(4, 8),
Down_Block(8, 16),
Down_Block(16, 32),
Down_Block(32, 64),
Down_Block(64, 128),
Down_Block(128, 256),
CBP(256, 256),
Up_Block(256, 128),
Up_Block(128, 64),
Up_Block(64, 32),
Up_Block(32, 16),
Up_Block(16, 8),
Up_Block(8, 4),
nn.Conv2d(4,3,1),
nn.Tanh()
)
def forward(self,y):
return self.block1(y)
if __name__ == '__main__':
net = PiercingEye()
x = torch.Tensor(2,3,512,512)
y = net(x)
print(y.shape)
简单的说明一下,其实就是做了6次下采样和6次上采样,也就是AE网络,中间没有任何跳跃连接,也可以理解成是一个生成网络。
2.数据集准备
我直接给大家一个百度云的链接,这也是一个开源的数据集,我稍微整理了一下,方便大家使用
链接:https://pan.baidu.com/s/1ir5xmYJWAX8QIXHb6_5zWw
提取码:5ph7
里面一共两个文件夹,data_train,data_val截图以示清白。
数据大概就是这个样子的,左边是ok的,右边是ng的,不是菊花,不是菊花,不是菊花,重要的事说三遍。
data_train里面一共784张图片,都是ok图片
data_val里面一共366张图片,150张ng图片和216张ok图片
我们训练只训练OK图片,看看能不能通过只训练OK图片来判断验证集里面的OK和NG
3.Dataset
有了网络,有了数据,就该来处理数据准备往网络里送了,上代码
import torch
import os
from torch.utils.data import Dataset
import random
import data_agumentation
import torchvision.transforms as tf
import cv2
transform = tf.Compose([tf.ToTensor(),tf.Normalize([0.5],[0.5])])
class train_data(Dataset):
def __init__(self, path):
print('start build_train_data')
self.path=path
self.imgs=[]
for i in os.listdir(path):
self.imgs.append(i)
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
random_num = random.randint(0,2)
img=cv2.imread(self.path+'/'+self.imgs[index])
if random_num == 1:
img = data_agumentation.augment_left_flip(img)
elif random_num == 2:
img = data_agumentation.augment_rotate(img,180)
img = transform(img)
return img,img
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data=train_data(r'D:\\blog_project\\guligedong_unsupervised\\data\\data_train')
train_loader = torch.utils.data.DataLoader(data, batch_size=3, shuffle=True)
for index,(img, label) in enumerate(train_loader):
print(img.size())
print(label.size())
print()
你应该很熟悉,因为上一篇也写过类似的,这你会看到,其实img和label是一样的, 我们的目的也是输入一张图片,让他生成一张一样的图片,这样是为什么呢?我的思路就是,因为我的训练集只有ok的,网络只能生成ok的特征,如果输入的是ng的图片,那么网络就不能生成ng的特征,这个时候就会有差异了。
def __getitem__(self, index):
random_num = random.randint(0,2)
img=cv2.imread(self.path+'/'+self.imgs[index])
if random_num == 1:
img = data_agumentation.augment_left_flip(img)
elif random_num == 2:
img = data_agumentation.augment_rotate(img,180)
img = transform(img)
return img,img
在这个代码段里,我用了随机的样本增强,这个样本增强也是我前面文章中提供给大家的,就是一个水平的镜像翻转和180度的旋转,当然你还可以增加90和270度的旋转,大家也可以看到,这个dataset就简单了很多,因为非监督学习没有标签,或者说,非监督学习样本的标签就是它本身。
今天就先更新到这里,下一篇我即将要更新训练代码和测试代码,以及整个项目的代码,尽情期待。
顺便问一下原力值是个什么东西,可以当饭吃吗?
至此,敬礼,salute!!!!
以上是关于Pytorch深度学习50篇·······第三篇:非监督学习的主要内容,如果未能解决你的问题,请参考以下文章
(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解
Pytorch深度学习50篇·······第一篇:认识深度学习
Pytorch深度学习50篇·······第六篇:常见损失函数篇-----BCELoss及其变种