A Multi-Scale 2D Feature Fusion Network Based on Dilated Convolution


  Dilated convolution module 2 replaces the convolution at every level of the encoder. Figure 3.4 shows its structure. The module uses five convolution groups at different scales, built from dilated convolutions with dilation rates 1, 2, 4, 8 and 16. The first group is a single 3×3 convolution with dilation rate 1, giving each pixel a 3×3 receptive field. The second group cascades 3×3 convolutions with dilation rates 1 and 2, for a 7×7 receptive field. Following this pattern, each subsequent group appends one more dilated convolution with twice the previous dilation rate to its cascade, so the third, fourth and fifth groups reach receptive fields of 15×15, 31×31 and 63×63, respectively. In every parallel feature-processing branch of the modified module 2, the dilated convolutions achieve full parameter coverage, so no information is lost. At the end of the module, the five feature maps at different scales are concatenated, and the concatenated feature map is passed through an ordinary convolution; together these operations realize the multi-scale feature fusion.
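To make the receptive-field arithmetic concrete, here is a minimal sketch (not part of the original training script): stacking a 3×3 convolution with dilation rate d on a branch enlarges its receptive field by 2d, which reproduces the 3, 7, 15, 31 and 63 values quoted above.

# Minimal sketch: receptive field of each cascaded group in the module.
# A 3x3 convolution with dilation rate d adds (3 - 1) * d = 2d pixels of context.
def receptive_field(dilation_rates):
    rf = 1
    for d in dilation_rates:
        rf += 2 * d
    return rf

groups = [[1], [1, 2], [1, 2, 4], [1, 2, 4, 8], [1, 2, 4, 8, 16]]
print([receptive_field(g) for g in groups])  # [3, 7, 15, 31, 63]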

 

[Figure 3.4: structure of the dilated convolution module (image not reproduced)]

 

Network training script

 

import keras
from keras.models import *
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dropout
from keras.optimizers import *

from keras.layers import Concatenate

from keras import backend as K

import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint
from fit_generator import get_path_list, get_train_batch


train_batch_size = 1
epoch = 1

data_train_path = "../deform/train_label_dir/train3"
data_label_path = "../deform/train_label_dir/label3"

train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)


# A LossHistory callback that records loss and Dice values per batch and per epoch
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs={}):
        self.losses['batch'].append(logs.get('loss'))
        self.accuracy['batch'].append(logs.get('dice_coef'))
        self.val_loss['batch'].append(logs.get('val_loss'))
        self.val_acc['batch'].append(logs.get('val_acc'))

    def on_epoch_end(self, batch, logs={}):
        self.losses['epoch'].append(logs.get('loss'))
        self.accuracy['epoch'].append(logs.get('dice_coef'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.val_acc['epoch'].append(logs.get('val_acc'))

    def loss_plot(self, loss_type):
        iters = range(len(self.losses[loss_type]))
        plt.figure(1)
        # training Dice curve
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train dice')
        if loss_type == 'epoch':
            # validation Dice curve
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('dice')
        plt.legend(loc="best")
        plt.savefig('./curve_figure/tune_liver/unet_liver2_raw_0_129_dialtion_all_entropy_dice_curve2.png')

        plt.figure(2)
        # training loss curve
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # validation loss curve
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('loss')
        plt.legend(loc="best")
        plt.savefig('./curve_figure/tune_liver/unet_liver2_raw_0_129_dialtion_all_entropy_loss_curve2.png')
        plt.show()


def dice_coef(y_true, y_pred):
    # soft Dice coefficient with a smoothing term to avoid division by zero
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1. - dice_coef(y_true, y_pred)


def mycrossentropy(y_true, y_pred, e=0.1):
    # cross entropy blended with a uniform-target term (label-smoothing style regularization)
    nb_classes = 10
    loss1 = K.categorical_crossentropy(y_true, y_pred)
    loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / nb_classes, y_pred)
    return (1 - e) * loss1 + e * loss2


class myUnet(object):
    def __init__(self, img_rows=512, img_cols=512):
        self.img_rows = img_rows
        self.img_cols = img_cols

    def dilation_conv(self, kernel_num, kernel_size, input):
        # multi-scale dilated-convolution module: branches with dilation rates
        # 1, 2, 4, 8, 16 applied to the same input and concatenated on the channel axis
        conv1_1 = Conv2D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal',
                         dilation_rate=(1, 1))(input)
        conv1_2 = Conv2D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal',
                         dilation_rate=(2, 2))(input)
        conv1_3 = Conv2D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal',
                         dilation_rate=(4, 4))(input)
        conv1_4 = Conv2D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal',
                         dilation_rate=(8, 8))(input)
        conv1_5 = Conv2D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal',
                         dilation_rate=(16, 16))(input)
        merges = Concatenate(axis=3)([conv1_1, conv1_2, conv1_3, conv1_4, conv1_5])
        return merges

    def BN_operation(self, input):
        output = keras.layers.normalization.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
                                                               scale=True,
                                                               beta_initializer='zeros', gamma_initializer='ones',
                                                               moving_mean_initializer='zeros',
                                                               moving_variance_initializer='ones',
                                                               beta_regularizer=None,
                                                               gamma_regularizer=None, beta_constraint=None,
                                                               gamma_constraint=None)(input)
        return output

    def get_unet(self):
        inputs = Input((self.img_rows, self.img_cols, 1))

        # encoder: each level uses the dilated multi-scale module followed by a plain convolution
        conv1 = self.dilation_conv(64, 3, inputs)
        conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        # BN
        pool1 = self.BN_operation(pool1)

        conv2 = self.dilation_conv(128, 3, pool1)
        conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        # BN
        pool2 = self.BN_operation(pool2)

        conv3 = self.dilation_conv(256, 3, pool2)
        conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        # BN
        pool3 = self.BN_operation(pool3)

        conv4 = self.dilation_conv(512, 3, pool3)
        conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
        # BN
        pool4 = self.BN_operation(pool4)

        conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
        conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)
        # BN
        drop5 = self.BN_operation(drop5)

        # decoder: upsampling with skip connections, as in the original U-Net
        up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
            UpSampling2D(size=(2, 2))(drop5))
        merge6 = Concatenate(axis=3)([drop4, up6])
        conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
        conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

        up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
            UpSampling2D(size=(2, 2))(conv6))
        merge7 = Concatenate(axis=3)([conv3, up7])
        conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
        conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

        up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
            UpSampling2D(size=(2, 2))(conv7))
        merge8 = Concatenate(axis=3)([conv2, up8])
        conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
        conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

        up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
            UpSampling2D(size=(2, 2))(conv8))
        merge9 = Concatenate(axis=3)([conv1, up9])
        conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
        conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
        conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
        conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

        model = Model(inputs=inputs, outputs=conv10)

        # the loss function and accuracy metric can be customized here
        # model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
        model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy', dice_coef])
        print('model compile')
        return model

    def train(self):
        # model = self.get_unet()
        model = load_model('./model/dilation_tune/unet_liver2_dir1_dilation_all_entropy.hdf5',
                           custom_objects={'dice_coef': dice_coef})

        print("got unet")

        # the checkpoint saves the full model (architecture and weights)
        model_checkpoint = ModelCheckpoint('./model/dilation_tune/unet_liver2_dir2_dilation_all_entropy.hdf5',
                                           monitor='loss', verbose=1, save_best_only=True)
        print('Fitting model...')

        # create a LossHistory instance
        history = LossHistory()
        model.fit_generator(
            generator=get_train_batch(train_path_list, label_path_list, train_batch_size, 512, 512),
            epochs=epoch, verbose=1,
            steps_per_epoch=count // train_batch_size,
            callbacks=[model_checkpoint, history],
            workers=1)

        # plot the acc-loss curves
        history.loss_plot('batch')


if __name__ == '__main__':
    myunet = myUnet()
    myunet.train()
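The fit_generator helper module imported above is not reproduced in the post. For completeness, here is a hypothetical sketch of the interface the script relies on: get_path_list pairs image and label files and returns the sample count, and get_train_batch yields normalized (batch, 512, 512, 1) image/mask arrays indefinitely. The actual preprocessing in the original project may differ.

import os
import numpy as np
from PIL import Image

def get_path_list(data_train_path, data_label_path):
    # pair image and label files by sorted file name; the third value is the sample count
    train_list = sorted(os.path.join(data_train_path, f) for f in os.listdir(data_train_path))
    label_list = sorted(os.path.join(data_label_path, f) for f in os.listdir(data_label_path))
    return train_list, label_list, len(train_list)

def get_train_batch(train_paths, label_paths, batch_size, rows, cols):
    # endless generator of (image, mask) batches with shape (batch_size, rows, cols, 1)
    while True:
        for start in range(0, len(train_paths), batch_size):
            imgs, masks = [], []
            for img_path, lbl_path in zip(train_paths[start:start + batch_size],
                                          label_paths[start:start + batch_size]):
                img = np.array(Image.open(img_path).convert('L').resize((cols, rows)), np.float32) / 255.0
                lbl = np.array(Image.open(lbl_path).convert('L').resize((cols, rows)), np.float32)
                imgs.append(img[..., None])
                masks.append((lbl > 127).astype(np.float32)[..., None])
            yield np.stack(imgs), np.stack(masks)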

 

Model testing and evaluation follow the same procedure as the standard U-Net.
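As a rough usage sketch of that testing step (the checkpoint path is the one saved by the script above; 'test_img.npy' is a hypothetical placeholder for one preprocessed 512×512 grayscale slice; dice_coef is the same function defined in the training script):

import numpy as np
from keras.models import load_model

# load the checkpointed model, registering the custom Dice metric used at training time
model = load_model('./model/dilation_tune/unet_liver2_dir2_dilation_all_entropy.hdf5',
                   custom_objects={'dice_coef': dice_coef})
img = np.load('test_img.npy').reshape(1, 512, 512, 1)   # must match the network input shape
pred = model.predict(img, verbose=1)
mask = (pred[0, :, :, 0] > 0.5).astype(np.uint8)         # threshold the sigmoid output into a binary mask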

 
