Keras 自定义损失函数 - 生存分析截尾
Posted
技术标签:
【中文标题】Keras 自定义损失函数 - 生存分析截尾【英文标题】:Keras Custom Loss Function - Survival Analysis Censored 【发布时间】:2021-06-08 16:39:58 【问题描述】:我正在尝试在 Keras 中构建一个自定义损失函数 - 这是用于生存分析中的删失数据。
这个损失函数本质上是二元交叉熵,即多标签分类,但是损失函数中的求和项需要根据 Y_true 的标签可用性而变化。请参阅下面的示例:
示例 1:适用于 Y_True 的所有标签
Y_true = [0, 0, 0, 1, 1]
Y_pred = [0.1, 0.2, 0.2, 0.8, 0.7]
损失 = -1/5(log(0.9) + log(0.8) + log(0.8) + log(0.8) + log(0.7)) = 0.22
示例 2:只有两个标签可用于 Y_True
Y_true = [0, 0, -, -, -]
Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]
损失 = -1/2 (log(0.9) + log(0.8)) = 0.164
示例 3:只有一个标签可用于 Y_True
Y_true = [0, -, -, -, -]
Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]
损失 = -1 (log(0.9)) = 0.105
在示例一的情况下,我们的损失将通过上面的公式计算,K = 5。在示例二中,我们的损失将通过 K = 2 计算(即仅根据基本事实)。损失函数需要根据 Y_true 可用性进行调整。
我尝试过自定义 Keras 损失函数...但是我正在努力研究如何根据张量流中的 nan 索引进行过滤。有人对上述自定义损失函数的编码有什么建议吗?
def nan_binary_cross_entropy(y_actual, y_predicted):
stack = tf.stack((tf.is_nan(y_actual), tf.is_nan(y_predicted)),axis=1)
is_nans = K.any(stack, axis=1)
per_instance = tf.where(is_nans, tf.zeros_like(y_actual),
tf.square(tf.subtract(y_predicted, y_actual)))
FILTER HERE
return K.binary_crossentropy(y_filt, y_filt), axis=-1)
【问题讨论】:
this是同一个损失函数吗? 不幸的是,我实际上是在尝试编写损失函数,以便考虑的标签数量根据基本事实中的条目数量而变化!不过还是谢谢...! 我明白了。您想要进行二元交叉熵,其中可用标签的数量可能因训练实例而异(对吗?)。您能否发布一个示例,让您手动进行计算以显示所需的输出? 正确!好的,让我尝试添加一个示例:) 好的,还添加了另一个示例 :) 【参考方案1】:您可以使用tf.math.is_nan
和tf.math.multiply_no_nan
的组合来掩盖您的y_true
以获得所需的结果。
import numpy as np
import tensorflow as tf
y_true = tf.constant([
[0.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, np.nan, np.nan, np.nan],
[0.0, np.nan, np.nan, np.nan, np.nan],
])
y_pred = tf.constant([
[0.1, 0.2, 0.2, 0.8, 0.7],
[0.1, 0.2, 0.1, 0.9, 0.9],
[0.1, 0.2, 0.1, 0.9, 0.9],
])
def survival_loss_fn(y_true, y_pred):
# create a mask for NaN elements
mask = tf.cast(~tf.math.is_nan(y_true), tf.float32)
# sum along the row axis of the mask to find the `N`
# for each training instance
Ns = tf.math.reduce_sum(mask, 1)
# use `multiply_no_nan` to zero out the NaN in `y_pred`
fst = tf.math.multiply_no_nan(y_true, mask) * tf.math.log(y_pred)
snd = tf.math.multiply_no_nan(1.0 - y_true, mask) * tf.math.log(1.0 - y_pred)
return -tf.math.reduce_sum(fst + snd, 1) / Ns
survival_loss_fn(y_true, y_pred)
# <tf.Tensor: shape=(3,), [0.22629324, 0.16425204, 0.10536055], dtype=float32)>
【讨论】:
嘿-关于在训练/权重更新期间损失何时减少,然后越来越接近 0,然后是 NaN 的任何建议。这可能与损失函数有关吗?降低学习率有帮助...! 将enable_check_numerics 添加到您的训练脚本中是一个不错的工具。另外我会检查训练数据是否已损坏并且是否包含 NaN。以上是关于Keras 自定义损失函数 - 生存分析截尾的主要内容,如果未能解决你的问题,请参考以下文章