多类分割的广义骰子损失:keras 实现
Posted
技术标签:
【中文标题】多类分割的广义骰子损失:keras 实现【英文标题】:Generalized dice loss for multi-class segmentation: keras implementation 【发布时间】:2018-08-07 07:08:55 【问题描述】:我刚刚在 keras 中实现了广义 dice loss(dice loss 的多类版本),如ref 中所述:
(我的目标定义为:(batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes))
def generalized_dice_loss_w(y_true, y_pred):
# Compute weights: "the contribution of each label is corrected by the inverse of its volume"
Ncl = y_pred.shape[-1]
w = np.zeros((Ncl,))
for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
w = 1/(w**2+0.00001)
# Compute gen dice coef:
numerator = y_true*y_pred
numerator = w*K.sum(numerator,(0,1,2,3))
numerator = K.sum(numerator)
denominator = y_true+y_pred
denominator = w*K.sum(denominator,(0,1,2,3))
denominator = K.sum(denominator)
gen_dice_coef = numerator/denominator
return 1-2*gen_dice_coef
但一定有什么地方不对劲。我正在处理必须为 4 个类(1 个背景类和 3 个对象类,我有一个不平衡的数据集)分割的 3D 图像。第一件奇怪的事情:虽然我的训练损失和准确度在训练期间有所提高(并且收敛得非常快),但我的验证损失/准确度是恒定的低谷时期(参见image)。其次,在对测试数据进行预测时,只预测背景类:I get a constant volume。
我使用了完全相同的数据和脚本,但使用了分类交叉熵损失并得到了合理的结果(对象类被分段)。这意味着我的实现有问题。知道它可能是什么吗?
另外,我相信 keras 社区有一个通用的 dice loss 实现会很有用,因为它似乎被用于最近的大多数语义分割任务(至少在医学图像社区中)。
PS:对我来说权重是如何定义的似乎很奇怪;我得到大约 10^-10 的值。还有其他人尝试过实现这一点吗?我还测试了没有权重的函数,但遇到了同样的问题。
【问题讨论】:
嗨@Manu,你有想过这个吗? 【参考方案1】:我认为这里的问题是你的体重。想象一下,您正在尝试解决多类分割问题,但在每张图像中只有少数类存在。一个玩具示例(也是导致我遇到此问题的示例)是通过以下方式从 mnist 创建分段数据集。
x = 28x28 图像和 y = 28x28x11 如果每个像素低于归一化灰度值 0.4,则将其分类为背景,否则将其分类为 x 的原始类别的数字。所以如果你看到一张第一的图片,你会有一堆像素归为一个,还有背景。
现在,在这个数据集中,图像中只会出现两个类。这意味着,在您丢掉骰子后,其中的 9 个权重将
1./(0. + eps) = large
因此,对于每张图像,我们都会对所有 9 个不存在的类进行强烈惩罚。在这种情况下,网络想要找到的一个明显强的局部最小值是将所有内容预测为背景类。
我们确实希望惩罚任何不在图像中但不那么强烈的错误预测类别。所以我们只需要修改权重。我就是这样做的:
def gen_dice(y_true, y_pred, eps=1e-6):
"""both tensors are [b, h, w, classes] and y_pred is in logit form"""
# [b, h, w, classes]
pred_tensor = tf.nn.softmax(y_pred)
y_true_shape = tf.shape(y_true)
# [b, h*w, classes]
y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
# [b, classes]
# count how many of each class are present in
# each image, if there are zero, then assign
# them a fixed weight of eps
counts = tf.reduce_sum(y_true, axis=1)
weights = 1. / (counts ** 2)
weights = tf.where(tf.math.is_finite(weights), weights, eps)
multed = tf.reduce_sum(y_true * y_pred, axis=1)
summed = tf.reduce_sum(y_true + y_pred, axis=1)
# [b]
numerators = tf.reduce_sum(weights*multed, axis=-1)
denom = tf.reduce_sum(weights*summed, axis=-1)
dices = 1. - 2. * numerators / denom
dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
return tf.reduce_mean(dices)
【讨论】:
以上是关于多类分割的广义骰子损失:keras 实现的主要内容,如果未能解决你的问题,请参考以下文章