语义分割 Keras 的交叉熵损失
Posted
技术标签:
【中文标题】语义分割 Keras 的交叉熵损失【英文标题】:Cross Entropy Loss for Semantic Segmentation Keras 【发布时间】:2017-06-26 09:46:35 【问题描述】:我很确定这是一个愚蠢的问题,但我在其他任何地方都找不到,所以我要在这里问。
我正在使用具有 7 个标签的 keras 中的 cnn (unet) 进行语义图像分割。因此,我使用 theano 后端对每个图像的标签是 (7,n_rows,n_cols)。因此,在每个像素的 7 层中,它是 one-hot 编码的。在这种情况下,使用分类交叉熵的错误函数是否正确?对我来说似乎是这样,但网络似乎通过二元交叉熵损失学习得更好。有人可以解释为什么会这样以及原则目标是什么?
【问题讨论】:
【参考方案1】:二进制交叉熵损失应该与最后一层的sigmod
激活一起使用,它会严重惩罚相反的预测。它没有考虑到输出是单热编码的,并且预测的总和应该是 1。但是由于错误预测严重惩罚了模型,因此在某种程度上学会了正确分类。
现在强制执行 one-hot 代码的先验是使用具有分类交叉熵的 softmax
激活。这是你应该使用的。
现在问题是在您的情况下使用softmax
,因为 Keras 不支持每个像素上的 softmax。
最简单的方法是使用Permute
层将维度置换为 (n_rows,n_cols,7),然后使用Reshape
层将其重塑为 (n_rows*n_cols,7)。然后您可以添加softmax
激活层并使用交叉熵损失。数据也应该相应地重新调整。
另一种方法是实现depth-softmax:
def depth_softmax(matrix):
sigmoid = lambda x: 1 / (1 + K.exp(-x))
sigmoided_matrix = sigmoid(matrix)
softmax_matrix = sigmoided_matrix / K.sum(sigmoided_matrix, axis=0)
return softmax_matrix
并将其用作 lambda 层:
model.add(Deconvolution2D(7, 1, 1, border_mode='same', output_shape=(7,n_rows,n_cols)))
model.add(Permute(2,3,1))
model.add(BatchNormalization())
model.add(Lambda(depth_softmax))
如果使用tf
image_dim_ordering
,那么您可以使用Permute
层。
更多参考请查看here。
【讨论】:
感谢您非常详细的回答!我使用了 reshape、softmax 和分类交叉熵。您是否期望这两种方法在速度或最终准确性方面存在任何实质性的性能差异?再次感谢! 我自己没有处理过这个场景,但你可以检查它们。您还可以尝试的另一件事是首先创建一个模型,其最后一层为sigmoid
和二元交叉熵损失,一旦训练完成,替换顶层并以softmax
结束,并使用分类交叉熵重新训练。第二次训练会很快收敛,但我敢打赌总体上训练时间会减少并且会有更好的准确性。
嗨,indraforyou,我也在研究语义分割案例。蒙版图像表示为 (1,n_rows, n_cols)。对于这种情况,我可以使用 sigmoid 和二元交叉熵吗?是否有任何具体的程序要包括在内?
@user297850 对于单通道,您不需要做任何特别的事情。您可以简单地使用 sigmoid 和二元交叉熵。
@indraforyou 我对你的功能有点困惑。 axis=0
表示您沿行取总和?我假设它是axis=-1
,因为您想沿深度求和(但是,用axis=0
代替axis=1
不起作用,它会引发错误)。另外,这真的会导致softmax吗?我会假设它是lambda x: K.exp(x)
...【参考方案2】:
我测试了@indraforyou 的解决方案,认为提出的方法有一些错误。由于 cmetsection 不允许正确的代码段,我认为这是固定版本:
def depth_softmax(matrix):
from keras import backend as K
exp_matrix = K.exp(matrix)
softmax_matrix = exp_matrix / K.expand_dims(K.sum(exp_matrix, axis=-1), axis=-1)
return softmax_matrix
此方法将期望矩阵的顺序为(高度、宽度、通道)。
【讨论】:
以上是关于语义分割 Keras 的交叉熵损失的主要内容,如果未能解决你的问题,请参考以下文章