如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量
Posted
技术标签:
【中文标题】如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量【英文标题】:How correctly calculate tf.nn.weighted_cross_entropy_with_logits pos_weight variable 【发布时间】:2017-09-19 17:48:27 【问题描述】:我正在使用卷积神经网络。
我的数据很不平衡,我有两个类。
我的第一堂课包含:551,462 个图像文件
我的第二堂课包含:52,377 个图像文件
我想使用weighted_cross_entropy_with_logits
,但我不确定我是否正确计算pos_weight
变量。
我现在正在使用
classes_weights = tf.constant([0.0949784, 1.0])
cross_entropy = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y_, pos_weight=classes_weights))
train_step = tf.train.AdamOptimizer(LEARNING_RATE, epsilon=1e-03).minimize(
cross_entropy
, global_step=global_step
)
或者我应该使用
classes_weights = 10.5287
【问题讨论】:
【参考方案1】:来自文档:
pos_weight:用于正例的系数。
和
参数 pos_weight 用作正目标的乘数:
所以如果你的第一堂课是肯定的,那么pos_weights = 52,377 / 551,462
,否则551,462 / 52,377
【讨论】:
我是这么想的,但我看到了几个例子,人们使用类系数数组作为输入source。同样使用pos_weights = 10.5287
运行代码会使损失保持在非常高的水平。即使经过 60600 次迭代 * 50 的 mini batch 在某个时候达到平均损失超过 1.0,这似乎并不正确。而且似乎 1 类已经有了更好的准确度,而 2 类并没有提高那么好。
@DariusŠilkaitis 这就是文档所说的,我对它更信任,而不是对 SO 的一个孤独的回答。您尝试了我的方法并且对结果不满意,但是您是否尝试过另一种方法tf.constant([0.0949784, 1.0])
?
在这样的大数据上训练相当慢。所以我还没有深入尝试这两种解决方案。 tf.constant([0.0949784, 1.0])
损失对我的眼睛来说似乎太低了,但我得到了更好的准确性。我需要几天的时间来尝试这两种配置,每个配置至少 20 个 epoch。我会在这里更新结果。感谢您的帮助。
@SalvadorDali 尝试使用 pos_weight 的标量值的这种方法将所有多数类(0)分类为少数类(1),大大增加了误报。关于为什么会这样的任何线索?提前致谢。【参考方案2】:
正如@Salvador Dali 所说,最好的来源是源代码 https://github.com/tensorflow/tensorflow/blob/5b10b3474bea72e29875264bb34be476e187039c/tensorflow/python/ops/nn_impl.py#L183
我们有
log_weight = 1 + (pos_weight - 1) * targets
所以它只适用于targets==1
。
如果targets==0
那么log_weight = 1
如果targets==1
那么log_weight = pos_weight
因此,如果我们有正负比率 x/y,我们需要 pos_weight 为 y/x,这样两个类别的总贡献相同
请注意,目标张量中的每个标量对应于每个类别,因此 pos_weight 的每个成员也对应于每个类别(不是一个类别的正概率或负概率)。
【讨论】:
以上是关于如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量的主要内容,如果未能解决你的问题,请参考以下文章