Deep Learning UNet convergence
Posted: 2019-09-25 23:04:19

I am writing a deep learning UNet model for image segmentation (RGB 256×256 px images -> grayscale images), inspired by https://github.com/zhixuhao/unet, so my network has the following structure:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 256, 256, 16) 448 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 256, 16) 64 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 256, 256, 16) 2320 batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 16) 64 conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 128, 128, 16) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 32) 4640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128 conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 128, 128, 32) 9248 batch_normalization_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 32) 128 conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 64, 64, 32) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 64) 18496 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 64) 256 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 64) 36928 batch_normalization_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 64) 256 conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 32, 32, 64) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 128) 73856 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 128) 512 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 128) 147584 batch_normalization_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 128) 512 conv2d_8[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 32, 32, 128) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 16, 16, 128) 0 dropout_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 256) 295168 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 16, 16, 256) 1024 conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 256) 590080 batch_normalization_9[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 256) 1024 conv2d_10[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 16, 16, 256) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 32, 32, 256) 0 dropout_2[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 128) 131200 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 32, 32, 256) 0 dropout_1[0][0]
conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 128) 295040 concatenate_1[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 32, 32, 128) 512 conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 32, 32, 128) 147584 batch_normalization_11[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 32, 32, 128) 512 conv2d_13[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 64, 64, 128) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 64, 64, 64) 32832 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 64, 64, 128) 0 conv2d_6[0][0]
conv2d_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 64, 64, 64) 73792 concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 64, 64, 64) 256 conv2d_15[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 64, 64, 64) 36928 batch_normalization_13[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 64, 64, 64) 256 conv2d_16[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 128, 128, 64) 0 batch_normalization_14[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 128, 128, 32) 8224 up_sampling2d_3[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 128, 128, 64) 0 conv2d_4[0][0]
conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 128, 128, 32) 18464 concatenate_3[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 128, 128, 32) 128 conv2d_18[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 128, 128, 32) 9248 batch_normalization_15[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 32) 128 conv2d_19[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 256, 256, 32) 0 batch_normalization_16[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 256, 256, 16) 2064 up_sampling2d_4[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 256, 256, 32) 0 conv2d_2[0][0]
conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 256, 256, 16) 4624 concatenate_4[0][0]
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 256, 256, 16) 64 conv2d_21[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 256, 256, 16) 2320 batch_normalization_17[0][0]
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 256, 256, 16) 64 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 256, 256, 2) 290 batch_normalization_18[0][0]
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 256, 256, 2) 8 conv2d_23[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 256, 256, 2) 0 batch_normalization_19[0][0]
__________________________________________________________________________________________________
MLP_layer (Conv2D) (None, 256, 256, 1) 3 dropout_3[0][0]
==================================================================================================
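The parameter counts in this summary can be checked by hand. A Conv2D layer with a k×k kernel, c_in input channels and c_out filters has k·k·c_in·c_out weights plus c_out biases. BatchNormalization stores 4 values per channel (gamma, beta, moving mean, moving variance). A small sketch that reproduces a few rows of the table (the helper names are illustrative):

```python
def conv2d_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights (kernel x kernel x c_in x c_out) plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out


def batchnorm_params(channels: int) -> int:
    """gamma, beta, moving_mean and moving_variance for every channel."""
    return 4 * channels


# A few rows of the model summary
print(conv2d_params(3, 3, 16))      # conv2d_1: 448
print(conv2d_params(2, 256, 128))   # conv2d_11 (2x2 conv after upsampling): 131200
print(conv2d_params(3, 256, 128))   # conv2d_12 (after concatenate_1): 295040
print(batchnorm_params(16))         # batch_normalization_1: 64
print(conv2d_params(1, 2, 1))       # MLP_layer (1x1 conv): 3
```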
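However, convergence is very difficult, and it only works with a very limited set of parameters:
- the learning rate must not exceed 1e-3, whereas some articles use 1e-2 with decay
- the number of filters in the first convolution only works with 16 (32 in the next layer, etc.)
- batch size 8 or 16 works, while 32 and 64 do not
- batch_normalization is required, although the example base model does not use it. Batch normalization is supposed to help the network learn with less restrictive parameters... https://towardsdatascience.com/batch-normalization-theory-and-how-to-use-it-with-tensorflow-1892ca0173ad? https://arxiv.org/pdf/1502.03167.pdf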
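For reference, during training BatchNormalization (as described in the paper linked above) standardizes each channel using the mean and variance over the batch and spatial axes, then applies a learned scale and shift. Below is a minimal NumPy sketch of that training-mode forward pass for NHWC tensors. It assumes gamma=1 and beta=0 for illustration and omits the moving averages that Keras keeps for inference.

```python
import numpy as np


def batch_norm_train(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray,
                     eps: float = 1e-3) -> np.ndarray:
    """Training-mode batch norm for an NHWC tensor.

    Statistics are taken over batch, height and width, i.e. one mean and
    one variance per channel (the channel axis is the last one).
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta


rng = np.random.default_rng(0)
# Fake, badly scaled activations: batch of 8, 4x4 spatial, 16 channels
x = rng.normal(loc=5.0, scale=3.0, size=(8, 4, 4, 16))
y = batch_norm_train(x, gamma=np.ones(16), beta=np.zeros(16))
print(np.round(y.mean(axis=(0, 1, 2))[:3], 6))  # close to 0 for each channel
print(np.round(y.std(axis=(0, 1, 2))[:3], 3))   # close to 1 for each channel
```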
Other details:
- I checked that my input is np.float32, ranging from 0 to 1
- I am trying to learn the cadastre from satellite images
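Besides checking that the images are float32 in [0, 1], it is worth checking the masks too. `binary_crossentropy` with a sigmoid output expects targets in {0, 1}, and masks read as 0/255 from 8-bit images are a common pitfall. Below is a small hypothetical helper (`check_batch` does not come from the question). It assumes NHWC image batches and single-channel masks.

```python
import numpy as np


def check_batch(images: np.ndarray, masks: np.ndarray) -> list:
    """Return a list of human-readable problems found in one training batch."""
    problems = []
    if images.dtype != np.float32:
        problems.append(f"images dtype is {images.dtype}, expected float32")
    if images.min() < 0.0 or images.max() > 1.0:
        problems.append(f"images range [{images.min()}, {images.max()}] is outside [0, 1]")
    if images.shape[:3] != masks.shape[:3]:
        problems.append(f"image shape {images.shape} and mask shape {masks.shape} differ")
    values = np.unique(masks)
    if not np.all(np.isin(values, [0, 1])):
        problems.append(f"masks contain values other than 0/1: {values[:5]}")
    return problems


rng = np.random.default_rng(0)
images = rng.random((8, 256, 256, 3)).astype(np.float32)
masks_ok = (rng.random((8, 256, 256, 1)) > 0.9).astype(np.float32)
masks_255 = masks_ok * 255  # e.g. masks read straight from 8-bit PNG files
print(check_batch(images, masks_ok))   # []
print(check_batch(images, masks_255))  # one message about the mask values
```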
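So my question is:
Why can't my network work with the same parameters as the ones used in the referenced articles?
-> I have to set "slow" parameters to make it work (lower learning rate, smaller batch size, fewer convolution layers...). Otherwise it outputs a grayscale image with a single pixel value (the same value everywhere).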
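The failure described here, where the prediction has a single value everywhere, can be detected automatically, for example on a validation batch at the end of each epoch. This minimal sketch assumes `pred` is the model's sigmoid output as a NumPy array. The helper name and the tolerance are illustrative and do not come from the question.

```python
import numpy as np


def is_collapsed(pred: np.ndarray, tol: float = 1e-3) -> bool:
    """True if every predicted map in the batch is (almost) constant."""
    # Standard deviation of each image's prediction over H, W and channels
    per_image_std = pred.reshape(pred.shape[0], -1).std(axis=1)
    return bool(np.all(per_image_std < tol))


rng = np.random.default_rng(0)
collapsed = np.full((4, 256, 256, 1), 0.07, dtype=np.float32)  # "single pixel value"
healthy = rng.random((4, 256, 256, 1)).astype(np.float32)
print(is_collapsed(collapsed))  # True
print(is_collapsed(healthy))    # False
```

In Keras, a check like this could run inside a custom callback's `on_epoch_end`, applied to `model.predict(...)` on a held-out batch.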
Code used:
```python
from typing import Optional, Tuple

import tensorflow as tf
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dropout, Input,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

SHAPE = 256
DIM = 3
INITIALIZER = 'glorot_uniform'
BASE_SIZE = 16
LR = 0.001


def iou_loss(y_true, y_pred, smooth=1.0):
    # Stand-in for the author's custom metric, which the question does not show:
    # soft IoU (Jaccard index) between mask and prediction (higher is better).
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)


def get_model(pretrained_model: Optional[str] = None,
              input_size: Tuple[int, int, int] = (SHAPE, SHAPE, DIM)) -> Model:
    """
    Machine learning model for image learning, here the purpose is segmentation,
    thus there should be upsampling!

    Parameters
    ----------
    pretrained_model: str
        name of .hdf5 file containing pretrained weights
    input_size: tuple of int
        (height, width, channels) of the input images

    Returns
    -------
    Model
    """
    inputs = Input(input_size)

    # Encoder: two 3x3 conv + batch norm blocks per level, then 2x2 max pooling
    conv1 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(inputs)
    batch_norm1 = BatchNormalization()(conv1)
    conv2 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm1)
    batch_norm2 = BatchNormalization()(conv2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch_norm2)
    conv3 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool1)
    batch_norm3 = BatchNormalization()(conv3)
    conv4 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm3)
    batch_norm4 = BatchNormalization()(conv4)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch_norm4)
    conv5 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool2)
    batch_norm5 = BatchNormalization()(conv5)
    conv6 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm5)
    batch_norm6 = BatchNormalization()(conv6)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch_norm6)
    conv7 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool3)
    batch_norm7 = BatchNormalization()(conv7)
    conv8 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm7)
    batch_norm8 = BatchNormalization()(conv8)
    drop4 = Dropout(0.2)(batch_norm8)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv9 = Conv2D(BASE_SIZE * 16, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(pool4)
    batch_norm9 = BatchNormalization()(conv9)
    conv10 = Conv2D(BASE_SIZE * 16, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm9)
    batch_norm10 = BatchNormalization()(conv10)
    drop5 = Dropout(0.5)(batch_norm10)

    # Decoder: upsample, 2x2 conv, then concatenate with the encoder skip connection
    up6 = Conv2D(BASE_SIZE * 8, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv11 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge6)
    batch_norm11 = BatchNormalization()(conv11)
    conv12 = Conv2D(BASE_SIZE * 8, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm11)
    batch_norm12 = BatchNormalization()(conv12)
    up7 = Conv2D(BASE_SIZE * 4, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm12))
    merge7 = concatenate([conv6, up7], axis=3)
    conv13 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge7)
    batch_norm13 = BatchNormalization()(conv13)
    conv14 = Conv2D(BASE_SIZE * 4, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm13)
    batch_norm14 = BatchNormalization()(conv14)
    up8 = Conv2D(BASE_SIZE * 2, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm14))
    merge8 = concatenate([conv4, up8], axis=3)
    conv15 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge8)
    batch_norm15 = BatchNormalization()(conv15)
    conv16 = Conv2D(BASE_SIZE * 2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        batch_norm15)
    batch_norm16 = BatchNormalization()(conv16)
    up9 = Conv2D(BASE_SIZE, 2, activation='relu', padding='same', kernel_initializer=INITIALIZER)(
        UpSampling2D(size=(2, 2))(batch_norm16))
    merge9 = concatenate([conv2, up9], axis=3)
    conv17 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(merge9)
    batch_norm17 = BatchNormalization()(conv17)
    conv18 = Conv2D(BASE_SIZE, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm17)
    batch_norm18 = BatchNormalization()(conv18)
    conv19 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer=INITIALIZER)(batch_norm18)
    batch_norm19 = BatchNormalization()(conv19)
    # Personal addition (renamed so that drop4 / conv10 above are not overwritten)
    drop_out = Dropout(0.2)(batch_norm19)
    outputs = Conv2D(1, 1, activation='sigmoid', name='MLP_layer')(drop_out)

    # Old Keras 2 accepted input=/output=; current Keras requires inputs=/outputs=
    model = Model(inputs=inputs, outputs=outputs)
    # Adam's `lr` argument has been renamed to `learning_rate`
    model.compile(optimizer=Adam(learning_rate=LR),
                  loss='binary_crossentropy',
                  metrics=['accuracy', iou_loss])
    if pretrained_model:
        # The question returned read_model(pretrained_model), a helper that is
        # not shown; as an assumption, the weights are loaded into the new model.
        model.load_weights(pretrained_model)
    return model
```
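The code above trains with a constant `LR = 0.001`, whereas the question mentions articles that use 1e-2 with decay. One way to try that is a schedule function passed to Keras' `LearningRateScheduler` callback, which calls it as `schedule(epoch, lr)`. The sketch below is plain Python, and the initial rate, drop factor and step size are illustrative assumptions, not values taken from those articles.

```python
import math


def step_decay(epoch: int, lr: float = 0.0, initial_lr: float = 1e-2,
               drop: float = 0.5, epochs_per_drop: int = 10) -> float:
    """Halve the learning rate every `epochs_per_drop` epochs.

    `lr` (the current rate) is accepted because LearningRateScheduler passes
    it, but this schedule depends only on the epoch index.
    """
    return initial_lr * math.pow(drop, math.floor(epoch / epochs_per_drop))


for epoch in (0, 9, 10, 25):
    print(epoch, step_decay(epoch))
```

With the model above, this would be passed as `callbacks=[tf.keras.callbacks.LearningRateScheduler(step_decay)]` to `model.fit`. Whether 1e-2 then converges on this dataset still has to be checked empirically.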
Thanks
Comments on the question:
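I ran into a similar problem when using Unet. BatchNormalization is essential, and the batch_sizes and learning rate also have to be low. I would also be interested in the correct answer.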
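What is the goal of your UNet? Image segmentation as well? @Anakin
Yes. On satellite data.
@Anakin project.inria.fr/aerialimagelabeling? This is the challenge I am trying to solve.
Can you post the code used to create the network? Are you sure you are using sigmoid in the last layer?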
Answer 1:
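Binary cross-entropy does not work well for segmentation problems, especially when there is class imbalance. For example, if the masks on average contain many more black pixels than white ones, your network may find it good enough to predict everything as black. Try Dice loss or Jaccard loss as your objective function, or use the sum of Dice or Jaccard loss with binary cross-entropy or weighted binary cross-entropy. Finally, you can take a look at this library: https://segmentation-models.readthedocs.io/en/latest/install.html. It provides several segmentation models, including Unet with different pretrained encoders, and the most common metrics for this task (e.g. Dice and Jaccard).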
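To make the answer's point concrete: with a heavily imbalanced mask, an all-black prediction gets high pixel accuracy while its Dice loss stays near the worst value of 1. The NumPy sketch below implements soft Dice loss and the suggested BCE + Dice sum. It assumes masks in {0, 1} and sigmoid probabilities. For training, the same formulas would have to be written with TensorFlow ops (or taken from the library linked above).

```python
import numpy as np


def dice_loss(y_true: np.ndarray, y_pred: np.ndarray, smooth: float = 1.0) -> float:
    """Soft Dice loss: 1 - 2|A∩B| / (|A| + |B|), smoothed to avoid 0/0."""
    intersection = np.sum(y_true * y_pred)
    return float(1.0 - (2.0 * intersection + smooth)
                 / (np.sum(y_true) + np.sum(y_pred) + smooth))


def bce(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy, with predictions clipped away from 0 and 1."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return float(-np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p)))


def bce_dice_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Sum of BCE and Dice loss, one of the combinations suggested above."""
    return bce(y_true, y_pred) + dice_loss(y_true, y_pred)


# Imbalanced toy mask: about 2% foreground
mask = np.zeros((256, 256), dtype=np.float32)
mask[:, :5] = 1.0  # 5 of 256 columns
all_black = np.zeros_like(mask)

accuracy = np.mean((all_black > 0.5) == (mask > 0.5))
print(f"pixel accuracy of all-black prediction: {accuracy:.3f}")                # 0.980
print(f"dice loss of all-black prediction: {dice_loss(mask, all_black):.3f}")   # 0.999
print(f"dice loss of perfect prediction: {dice_loss(mask, mask):.3f}")          # 0.000
```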