Incompatible shape problem with Keras (segmentation model) for batch_size>1

【Posted】: 2020-12-31 05:05:16

【Question】:

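I am trying to do semantic segmentation of multi-channel (>3) images using the Unet from the segmentation_models library (imported as sm). The code works if batch_size = 1. However, if I change batch_size to any other value (e.g. 2), an error occurs (InvalidArgumentError: Incompatible shapes):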

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-19-15dc3666afa8> in <module>
     22                                     validation_steps = 1,
     23                                     callbacks=build_callbacks(),
---> 24                                     verbose = 1)
     25 

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1424             use_multiprocessing=use_multiprocessing,
   1425             shuffle=shuffle,
-> 1426             initial_epoch=initial_epoch)
   1427 
   1428     @interfaces.legacy_generator_methods_support

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    189                 outs = model.train_on_batch(x, y,
    190                                             sample_weight=sample_weight,
--> 191                                             class_weight=class_weight)
    192 
    193                 if not isinstance(outs, list):

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1218             ins = x + y + sample_weights
   1219         self._make_train_function()
-> 1220         outputs = self.train_function(ins)
   1221         if len(outputs) == 1:
   1222             return outputs[0]

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2659                 return self._legacy_call(inputs)
   2660 
-> 2661             return self._call(inputs)
   2662         else:
   2663             if py_any(is_tensor(x) for x in inputs):

~/.virtualenvs/sm/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2629                                 symbol_vals,
   2630                                 session)
-> 2631         fetched = self._callable_fn(*array_vals)
   2632         return fetched[:len(self.outputs)]
   2633 

~/.virtualenvs/sm/lib/python3.6/site-packages/tensorflow_core/python/client/session.py in __call__(self, *args, **kwargs)
   1470         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1471                                                self._handle, args,
-> 1472                                                run_metadata_ptr)
   1473         if run_metadata:
   1474           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: Incompatible shapes: [2,256,256,1] vs. [2,256,256]
     [[node loss_1/model_4_loss/mul]]

I tried the suggestions from several forum posts but could not solve it. Here is the part of the code that runs with batch_size=1.

from glob import glob

batch_size = 1  # changing batch_size to any value other than 1 gives the error
train_image_files = glob(patch_img + "/**/*.tif")
# simple_image_generator() is used to work with multi-channel (>3) images
# (the function is at the end)
train_image_generator = simple_image_generator(train_image_files, 
                                         batch_size=batch_size,
                                         rotation_range=45,
                                         horizontal_flip=True,
                                         vertical_flip=True)

train_mask_files = glob(patch_ann + "/**/*.tif")
train_mask_generator = simple_image_generator(train_mask_files, 
                                              batch_size=batch_size)


test_image_files = glob(test_img + "/**/*.tif")
test_image_generator = simple_image_generator(test_image_files, 
                                         batch_size=batch_size,
                                         rotation_range=45,
                                         horizontal_flip=True,
                                         vertical_flip=True)

test_mask_files = glob(test_ann + "/**/*.tif")
test_mask_generator = simple_image_generator(test_mask_files, 
                                              batch_size=batch_size)

train_generator = (pair for pair in zip(train_image_generator, train_mask_generator))
test_generator = (pair for pair in zip(test_image_generator, test_mask_generator))


.
.
num_channels = 8 # no. of channel
base_model = sm.Unet(backbone_name='resnet34', encoder_weights='imagenet')
inp = Input(shape=( None, None, num_channels))
layer_1 = Conv2D( 3, (1, 1))(inp) # map N channels data to 3 channels
out = base_model(layer_1)
model = Model(inp, out, name=base_model.name)
model.summary()

model.compile(
    optimizer = keras.optimizers.Adam(lr=learning_rate),
    loss = sm.losses.bce_jaccard_loss,
    metrics = ['accuracy',sm.metrics.iou_score]
)
model_history = model.fit_generator(train_generator, 
                                    epochs = 1, 
                                    steps_per_epoch = 1,
                                    validation_data = test_generator, 
                                    validation_steps = 1,
                                    callbacks = build_callbacks(),
                                    verbose = 1)

Additional information: I am not using the default image generator provided by Keras. I am using 'simple_image_generator' (slightly modified):

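# Imports assumed for this snippet (the original post does not show them).
from random import sample, choice
import numpy as np
from skimage.io import imread  # assumption: any TIFF reader that returns an array would do
from skimage.transform import rotate

def simple_image_generator(files, batch_size=32,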
                           rotation_range=0, horizontal_flip=False,
                           vertical_flip=False):
    while True:
        # select batch_size number of samples without replacement
        batch_files = sample(files, batch_size)
       
        
        # array for images
        batch_X = []
        # loop over images of the current batch
        for idx, input_path in enumerate(batch_files):
            image = np.array(imread(input_path), dtype=float)
 
            # process image
            if horizontal_flip:
                # randomly flip image left/right
                if choice([True, False]):
                    image = np.fliplr(image)
            if vertical_flip:
                # randomly flip image up/down
                if choice([True, False]):
                    image = np.flipud(image)
            # rotate image by random angle between
            # -rotation_range <= angle < rotation_range
            if rotation_range != 0:
                angle = np.random.uniform(low=-abs(rotation_range),
                                          high=abs(rotation_range))
                image = rotate(image, angle, mode='reflect',
                               order=1, preserve_range=True)
            # put all together
            batch_X += [image]
             
        # convert lists to np.array
        X = np.array(batch_X)
            
        yield X

【Answer 1】:

The error was resolved by defining a new image generator to replace simple_image_generator(). simple_image_generator() handles the shape of the images (8 bands) correctly, but not the shape of the masks (1 band).

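At run time one tensor had 4 dimensions, [2,256,256,1] (batch_size, height, width, bands), while the batch from the mask generator had only 3, [2,256,256] (batch_size, height, width). Since the input images have 8 bands, the 4-dimensional tensor is most likely the model's single-channel prediction rather than the image batch itself. A single-band mask file is presumably read as a 2-D array, so stacking a batch of masks never adds the trailing channel axis.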
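The mismatch can be reproduced with plain NumPy, whose broadcasting rules are the same ones applied to an element-wise multiply like the `mul` node named in the error. This is only a sketch with zero-filled stand-in arrays: the names `y_pred`, `y_true` and `multiply_shapes` are illustrative and do not come from the post, and the spatial size is shrunk from 256 to 4 to keep the arrays small. It also suggests why batch_size=1 appears to work: the two shapes broadcast silently into an unintended 4-D result instead of raising an error.

```python
import numpy as np

def multiply_shapes(batch_size, size=4):
    """Multiply a 4-D stand-in prediction by a 3-D stand-in mask batch.

    Returns the broadcast result shape, or None if the shapes are incompatible.
    """
    y_pred = np.zeros((batch_size, size, size, 1))  # like [N, H, W, 1]
    y_true = np.zeros((batch_size, size, size))     # like [N, H, W], no channel axis
    try:
        return (y_pred * y_true).shape
    except ValueError:
        # NumPy refuses to broadcast these shapes
        return None

print(multiply_shapes(1))  # (1, 4, 4, 4): no error, but not the intended shape
print(multiply_shapes(2))  # None: shapes cannot be broadcast
```

If this reading is right, training with batch_size=1 ran without complaint but computed the loss on a broadcast tensor, so the missing channel axis is worth fixing even for that case.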

Reshaping the mask batch from [2,256,256] to [2,256,256,1] therefore solved the problem.
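One way to apply that reshape without rewriting the whole generator is to wrap the mask generator and append a channel axis to each batch. The following is a minimal sketch, not the answerer's actual replacement generator (which is not shown). It assumes masks arrive as (batch, height, width) arrays; `add_channel_axis` and `fake_mask_batches` are hypothetical helper names and are not part of Keras or segmentation_models.

```python
import numpy as np

def add_channel_axis(generator):
    """Yield each batch with a trailing channel axis if it is only 3-D."""
    for batch in generator:
        batch = np.asarray(batch)
        if batch.ndim == 3:  # (batch, height, width)
            batch = np.expand_dims(batch, axis=-1)  # -> (batch, height, width, 1)
        yield batch

def fake_mask_batches(batch_size=2, size=256):
    """Hypothetical stand-in for simple_image_generator on single-band masks."""
    while True:
        yield np.zeros((batch_size, size, size))

masks = add_channel_axis(fake_mask_batches())
first = next(masks)
print(first.shape)  # (2, 256, 256, 1)
```

In the question's code this would presumably mean wrapping train_mask_generator and test_mask_generator before they are passed to zip(); batches that already have four dimensions pass through unchanged.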
