Unet Semantic Segmentation Model (Keras) | Cell Images as an Example
Posted by 不想写代码
Preface
Recently, feeling a bit lost about which direction to pursue, I decided to study more of the models used in computer vision tasks. Unet struck me as an interesting model for semantic segmentation, so I set out to reproduce it.
I. What Is Semantic Segmentation
A semantic segmentation task looks like the figure below:
In short, semantic segmentation marks the different categories in an image with different colors, one color per category. It is commonly used on medical and satellite imagery.
So how do we assign a color to every pixel?
The output of semantic segmentation is actually similar to that of an image classification network. In image classification, the output is a one-dimensional one-hot vector over the classes, e.g., [0,1,0] for a three-class problem.
The final output feature map of a semantic segmentation network is a three-dimensional tensor: its spatial size matches the input image, and its channel count equals the number of classes, as shown in the figure below (image from Zhihu):
Each channel corresponds to one class, and the pixels marked in that channel indicate where that class appears in the image. Taking the argmax across channels at every pixel then merges the channels into a single map, with each class's location drawn in its own color. Semantic segmentation is really just classification at a finer granularity: it classifies every single pixel, finding the class each pixel belongs to. That's semantic segmentation in a nutshell~
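To make the channels-to-colors step concrete, here is a minimal sketch (my addition, not from the original post) that turns an (H, W, n_classes) score map into a colored mask with NumPy; the palette is an arbitrary assumption:

import numpy as np

# assumed example: raw scores from a segmentation net, shape (H, W, n_classes)
H, W, n_classes = 224, 224, 3
scores = np.random.rand(H, W, n_classes)

# one arbitrary RGB color per class (hypothetical palette)
palette = np.array([[0, 0, 0],       # class 0: background -> black
                    [255, 0, 0],     # class 1 -> red
                    [0, 255, 0]])    # class 2 -> green

class_map = np.argmax(scores, axis=-1)    # (H, W): per-pixel class index
color_mask = palette[class_map]           # (H, W, 3): per-pixel color

print(class_map.shape, color_mask.shape)  # (224, 224) (224, 224, 3)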
Now let's reproduce the Unet model.
II. Unet
1. Basic Principle
What is Unet? Its network structure is shown in the figure below:
The whole network is shaped like a "U" and can be split into two parts. The red box in the figure is the feature-extraction (encoder) part: like other convolutional neural networks, it extracts image features with stacked convolutions and compresses the feature maps with pooling. The blue box is the image-restoration (decoder) part (the term may not be strictly academic, but it gets the idea across): it restores the compressed feature maps through upsampling and convolution. The encoder can be any strong backbone, e.g., Resnet50 or VGG.
Note: since Resnet50 and VGG are rather large, this article uses Mobilenet as the backbone feature extractor. To make Unet easier to understand, I also built a small mini_unet from scratch. For convenience, the reproduction upsamples the compressed feature maps back to the same size as the input.
GitHub kept failing to upload, so for now the code lives on Gitee: https://gitee.com/Boss-Jian/unet
2. mini_unet
mini_unet was built to help you understand the overall flow of a semantic segmentation network; it is not a strong model for real segmentation tasks. Let's look at the implementation:
from keras.layers import Input, Conv2D, Dropout, MaxPooling2D, Concatenate, UpSampling2D
from keras.models import Model

def unet_mini(n_classes=21, input_shape=(224, 224, 3)):
    img_input = Input(shape=input_shape)
    #------------------------------------------------------
    # encoder part
    # 224,224,3 -> 112,112,32
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(img_input)
    conv1 = Dropout(0.2)(conv1)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D((2, 2), strides=2)(conv1)
    # 112,112,32 -> 56,56,64
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Dropout(0.2)(conv2)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D((2, 2), strides=2)(conv2)
    # 56,56,64 -> 56,56,128
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Dropout(0.2)(conv3)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    #-------------------------------------------------
    # decoder part
    # 56,56,128 -> 112,112,128
    up1 = UpSampling2D(2)(conv3)
    # 112,112,128 -> 112,112,128+64
    up1 = Concatenate(axis=-1)([up1, conv2])
    # 112,112,192 -> 112,112,64
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(up1)
    conv4 = Dropout(0.2)(conv4)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv4)
    # 112,112,64 -> 224,224,64
    up2 = UpSampling2D(2)(conv4)
    # 224,224,64 -> 224,224,64+32
    up2 = Concatenate(axis=-1)([up2, conv1])
    # 224,224,96 -> 224,224,32
    conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(up2)
    conv5 = Dropout(0.2)(conv5)
    conv5 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv5)
    # 224,224,32 -> 224,224,n_classes
    o = Conv2D(n_classes, 1, padding='same')(conv5)
    return Model(img_input, o, name="unet_mini")

if __name__ == "__main__":
    model = unet_mini()
    model.summary()
The encoder of mini_unet compresses the 224x224x3 input down to a 56x56x128 feature map, and the decoder upsamples it back to 224x224x32. The final convolution:
o = Conv2D(n_classes, 1, padding='same')(conv5)
adjusts the channel count of the feature map to match the number of classes.
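As a quick sanity check (my addition, assuming the code above is saved as unet_mini.py), you can build the model and confirm the output shape matches the input spatial size with one channel per class:

from unet_mini import unet_mini  # assumed module name

model = unet_mini(n_classes=21, input_shape=(224, 224, 3))
print(model.output_shape)  # (None, 224, 224, 21)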
3. Mobilenet_unet
Mobilenet_unet uses Mobilenet as the backbone feature extractor and loads pretrained weights to strengthen feature extraction. The decoder (restoration) part follows the same idea as above. Here is the Mobilenet_unet network:
from keras.models import *
from keras.layers import *
import keras.backend as K

IMAGE_ORDERING = "channels_last"  # channels last

def relu6(x):
    return K.relu(x, max_value=6)

def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
    filters = int(filters * alpha)
    x = ZeroPadding2D(padding=(1, 1), name='conv1_pad',
                      data_format=IMAGE_ORDERING)(inputs)
    x = Conv2D(filters, kernel, data_format=IMAGE_ORDERING,
               padding='valid',
               use_bias=False,
               strides=strides,
               name='conv1')(x)
    x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
    return Activation(relu6, name='conv1_relu')(x)

def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
                          depth_multiplier=1, strides=(1, 1), block_id=1):
    channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
    pointwise_conv_filters = int(pointwise_conv_filters * alpha)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING,
                      name='conv_pad_%d' % block_id)(inputs)
    x = DepthwiseConv2D((3, 3), data_format=IMAGE_ORDERING,
                        padding='valid',
                        depth_multiplier=depth_multiplier,
                        strides=strides,
                        use_bias=False,
                        name='conv_dw_%d' % block_id)(x)
    x = BatchNormalization(
        axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
    x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
    x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING,
               padding='same',
               use_bias=False,
               strides=(1, 1),
               name='conv_pw_%d' % block_id)(x)
    x = BatchNormalization(axis=channel_axis,
                           name='conv_pw_%d_bn' % block_id)(x)
    return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)

def get_mobilnet_eocoder(input_shape=(224, 224, 3), weights_path=""):
    # the input dimensions must be multiples of 32
    assert input_shape[0] % 32 == 0
    assert input_shape[1] % 32 == 0
    alpha = 1.0
    depth_multiplier = 1
    img_input = Input(shape=input_shape)
    # (None, 224, 224, 3) -> (None, 112, 112, 64)
    x = _conv_block(img_input, 32, alpha, strides=(2, 2))
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)
    f1 = x
    # (None, 112, 112, 64) -> (None, 56, 56, 128)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                              strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)
    f2 = x
    # (None, 56, 56, 128) -> (None, 28, 28, 256)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                              strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
    f3 = x
    # (None, 28, 28, 256) -> (None, 14, 14, 512)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                              strides=(2, 2), block_id=6)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
    f4 = x
    # (None, 14, 14, 512) -> (None, 7, 7, 1024)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                              strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
    f5 = x
    # load pretrained weights
    if weights_path != "":
        Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)
    # f1: (None, 112, 112, 64)
    # f2: (None, 56, 56, 128)
    # f3: (None, 28, 28, 256)
    # f4: (None, 14, 14, 512)
    # f5: (None, 7, 7, 1024)
    return img_input, [f1, f2, f3, f4, f5]

def mobilenet_unet(num_classes=2, input_shape=(224, 224, 3)):
    # encoder
    img_input, levels = get_mobilnet_eocoder(input_shape=input_shape,
                                             weights_path="model_data\\mobilenet_1_0_224_tf_no_top.h5")
    [f1, f2, f3, f4, f5] = levels
    # f1: (None, 112, 112, 64)
    # f2: (None, 56, 56, 128)
    # f3: (None, 28, 28, 256)
    # f4: (None, 14, 14, 512)
    # f5: (None, 7, 7, 1024)
    # decoder
    # (None, 14, 14, 512) -> (None, 14, 14, 512)
    o = f4
    o = ZeroPadding2D()(o)
    o = Conv2D(512, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 14, 14, 512) -> (None, 28, 28, 256)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f3])
    o = ZeroPadding2D()(o)
    o = Conv2D(256, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 28, 28, 256) -> (None, 56, 56, 128)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f2])
    o = ZeroPadding2D()(o)
    o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 56, 56, 128) -> (None, 112, 112, 128)
    o = UpSampling2D(2)(o)
    o = Concatenate(axis=-1)([o, f1])
    o = ZeroPadding2D()(o)
    o = Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    # (None, 112, 112, 128) -> (None, 224, 224, num_classes)
    # upsample once more so the output is the same size as the input
    o = UpSampling2D(2)(o)
    o = ZeroPadding2D()(o)
    o = Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING)(o)
    o = BatchNormalization()(o)
    o = Conv2D(num_classes, (3, 3), padding='same',
               data_format=IMAGE_ORDERING)(o)
    return Model(img_input, o)

if __name__ == "__main__":
    mobilenet_unet(input_shape=(512, 512, 3)).summary()
The feature-map shape changes and the meaning of each step are annotated directly in the code comments; do read through them carefully.
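The post stops at the architecture, so the lines below are only my sketch of one reasonable training setup, not the author's code. Since the final Conv2D outputs raw logits (no activation), I append a per-pixel softmax so plain categorical_crossentropy applies; train_gen/val_gen are placeholders for the UnetDataset Sequence from the next section:

from keras.layers import Activation
from keras.models import Model

base = mobilenet_unet(num_classes=2, input_shape=(224, 224, 3))
# the network ends in raw logits, so append a per-pixel softmax
probs = Activation('softmax')(base.output)
model = Model(base.input, probs)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# train_gen / val_gen: UnetDataset instances (see section 4)
# model.fit(train_gen, validation_data=val_gen, epochs=50)
# (older standalone Keras would use fit_generator instead of fit)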
4. Data Loading
import math
import os
from random import shuffle

import cv2
import keras
import numpy as np
from PIL import Image

#-------------------------------
# convert an image to RGB
#-------------------------------
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image

#-------------------------------
# normalize pixel values to the range -1~1
#-------------------------------
def preprocess_input(image):
    image = image / 127.5 - 1
    return image

#---------------------------------------------------
# resize the input image, padding to keep the aspect ratio
#---------------------------------------------------
def resize_image(image, size):
    iw, ih = image.size
    w, h = size
    scale = min(w/iw, h/ih)
    nw = int(iw*scale)
    nh = int(ih*scale)
    image = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))
    return new_image, nw, nh
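#---------------------------------------------------
# quick usage illustration (my added example, hypothetical file path):
# how cvtColor / resize_image / preprocess_input fit together
#---------------------------------------------------
# img = Image.open("cells/sample.png")              # hypothetical file
# img = cvtColor(img)                               # ensure RGB
# img, nw, nh = resize_image(img, (224, 224))       # letterbox to 224x224 with gray padding
# x = preprocess_input(np.array(img, np.float32))   # scale pixels to [-1, 1]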
class UnetDataset(keras.utils.Sequence):
    def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, dataset_path):
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path
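The source post is truncated at this point, so the rest of the class is missing. Purely as a sketch of what the missing methods of a keras.utils.Sequence usually look like (my assumption, including a hypothetical VOC-style dataset layout with JPEGImages/ and SegmentationClass/ folders, not the author's code):

    # --- continuing inside class UnetDataset: assumed sketch, not the original code ---
    def __len__(self):
        # number of batches per epoch
        return math.ceil(self.length / float(self.batch_size))

    def __getitem__(self, index):
        images, labels = [], []
        for i in range(index * self.batch_size, (index + 1) * self.batch_size):
            i = i % self.length
            name = self.annotation_lines[i].strip()
            # assumed dataset layout: JPEGImages/ for inputs, SegmentationClass/ for masks
            jpg = Image.open(os.path.join(self.dataset_path, "JPEGImages", name + ".jpg"))
            png = Image.open(os.path.join(self.dataset_path, "SegmentationClass", name + ".png"))
            jpg = cvtColor(jpg)
            jpg, _, _ = resize_image(jpg, (self.input_shape[1], self.input_shape[0]))
            # nearest-neighbor resize keeps mask values as valid class indices
            png = png.resize((self.input_shape[1], self.input_shape[0]), Image.NEAREST)
            png = np.array(png)
            png[png >= self.num_classes] = self.num_classes - 1  # clamp stray label values
            # one-hot encode: (H, W) class indices -> (H, W, num_classes)
            label = np.eye(self.num_classes)[png.reshape(-1)]
            label = label.reshape(self.input_shape[0], self.input_shape[1], self.num_classes)
            images.append(preprocess_input(np.array(jpg, np.float32)))
            labels.append(label)
        return np.array(images), np.array(labels)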