Keras 自定义损失函数 - 生存分析截尾

Posted

技术标签:

【中文标题】Keras 自定义损失函数 - 生存分析截尾【英文标题】:Keras Custom Loss Function - Survival Analysis Censored 【发布时间】:2021-06-08 16:39:58 【问题描述】:

我正在尝试在 Keras 中构建一个自定义损失函数 - 这是用于生存分析中的删失数据。

这个损失函数本质上是二元交叉熵,即多标签分类,但是损失函数中的求和项需要根据 Y_true 的标签可用性而变化。请参阅下面的示例:

示例 1:适用于 Y_True 的所有标签

Y_true = [0, 0, 0, 1, 1]

Y_pred = [0.1, 0.2, 0.2, 0.8, 0.7]

损失 = -1/5(log(0.9) + log(0.8) + log(0.8) + log(0.8) + log(0.7)) = 0.22

示例 2:只有两个标签可用于 Y_True

Y_true = [0, 0, -, -, -]

Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]

损失 = -1/2 (log(0.9) + log(0.8)) = 0.164

示例 3:只有一个标签可用于 Y_True

Y_true = [0, -, -, -, -]

Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]

损失 = -1 (log(0.9)) = 0.105

在示例一的情况下,我们的损失将通过上面的公式计算,K = 5。在示例二中,我们的损失将通过 K = 2 计算(即仅根据基本事实)。损失函数需要根据 Y_true 可用性进行调整。

我尝试过自定义 Keras 损失函数...但是我正在努力研究如何根据张量流中的 nan 索引进行过滤。有人对上述自定义损失函数的编码有什么建议吗?

def nan_binary_cross_entropy(y_actual, y_predicted):
    stack = tf.stack((tf.is_nan(y_actual), tf.is_nan(y_predicted)),axis=1)
    is_nans = K.any(stack, axis=1)
    per_instance = tf.where(is_nans, tf.zeros_like(y_actual), 
                            tf.square(tf.subtract(y_predicted, y_actual)))

    FILTER HERE
    return K.binary_crossentropy(y_filt, y_filt), axis=-1)

【问题讨论】:

this是同一个损失函数吗? 不幸的是,我实际上是在尝试编写损失函数,以便考虑的标签数量根据基本事实中的条目数量而变化!不过还是谢谢...! 我明白了。您想要进行二元交叉熵,其中可用标签的数量可能因训练实例而异(对吗?)。您能否发布一个示例,让您手动进行计算以显示所需的输出? 正确!好的,让我尝试添加一个示例:) 好的,还添加了另一个示例 :) 【参考方案1】:

您可以使用tf.math.is_nantf.math.multiply_no_nan 的组合来掩盖您的y_true 以获得所需的结果。

import numpy as np
import tensorflow as tf


y_true = tf.constant([
    [0.0, 0.0, 0.0, 1.0, 1.0],
    [0.0, 0.0, np.nan, np.nan, np.nan],
    [0.0, np.nan, np.nan, np.nan, np.nan],
])


y_pred = tf.constant([
    [0.1, 0.2, 0.2, 0.8, 0.7],
    [0.1, 0.2, 0.1, 0.9, 0.9],
    [0.1, 0.2, 0.1, 0.9, 0.9],
])


def survival_loss_fn(y_true, y_pred):
    # create a mask for NaN elements
    mask = tf.cast(~tf.math.is_nan(y_true), tf.float32)
    # sum along the row axis of the mask to find the `N`
    # for each training instance
    Ns = tf.math.reduce_sum(mask, 1)
    # use `multiply_no_nan` to zero out the NaN in `y_pred`
    fst = tf.math.multiply_no_nan(y_true, mask) * tf.math.log(y_pred)
    snd = tf.math.multiply_no_nan(1.0 - y_true, mask) * tf.math.log(1.0 - y_pred)
    return -tf.math.reduce_sum(fst + snd, 1) / Ns


survival_loss_fn(y_true, y_pred)
# <tf.Tensor: shape=(3,), [0.22629324, 0.16425204, 0.10536055], dtype=float32)>

【讨论】:

嘿-关于在训练/权重更新期间损失何时减少,然后越来越接近 0,然后是 NaN 的任何建议。这可能与损失函数有关吗?降低学习率有帮助...! 将enable_check_numerics 添加到您的训练脚本中是一个不错的工具。另外我会检查训练数据是否已损坏并且是否包含 NaN。

以上是关于Keras 自定义损失函数 - 生存分析截尾的主要内容,如果未能解决你的问题,请参考以下文章

在 Keras 中构建自定义损失函数

Keras 中的自定义损失函数(IoU 损失函数)和梯度误差?

Keras 自定义损失函数

Keras 中基于输入数据的自定义损失函数

Keras 上的自定义损失函数

Keras 自定义损失函数 dtype 错误