如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量

Posted

技术标签:

【中文标题】如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量【英文标题】:How correctly calculate tf.nn.weighted_cross_entropy_with_logits pos_weight variable 【发布时间】:2017-09-19 17:48:27 【问题描述】:

我正在使用卷积神经网络。

我的数据很不平衡,我有两个类。

我的第一堂课包含:551,462 个图像文件

我的第二堂课包含:52,377 个图像文件

我想使用weighted_cross_entropy_with_logits,但我不确定我是否正确计算pos_weight 变量。

我现在正在使用

classes_weights = tf.constant([0.0949784, 1.0])
cross_entropy = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y_, pos_weight=classes_weights))
train_step = tf.train.AdamOptimizer(LEARNING_RATE, epsilon=1e-03).minimize(
      cross_entropy
    , global_step=global_step
    )

或者我应该使用

classes_weights = 10.5287

【问题讨论】:

【参考方案1】:

来自文档:

pos_weight:用于正例的系数。

参数 pos_weight 用作正目标的乘数:

所以如果你的第一堂课是肯定的,那么pos_weights = 52,377 / 551,462,否则551,462 / 52,377

【讨论】:

我是这么想的,但我看到了几个例子,人们使用类系数数组作为输入source。同样使用pos_weights = 10.5287 运行代码会使损失保持在非常高的水平。即使经过 60600 次迭代 * 50 的 mini batch 在某个时候达到平均损失超过 1.0,这似乎并不正确。而且似乎 1 类已经有了更好的准确度,而 2 类并没有提高那么好。 @DariusŠilkaitis 这就是文档所说的,我对它更信任,而不是对 SO 的一个孤独的回答。您尝试了我的方法并且对结果不满意,但是您是否尝试过另一种方法tf.constant([0.0949784, 1.0]) 在这样的大数据上训练相当慢。所以我还没有深入尝试这两种解决方案。 tf.constant([0.0949784, 1.0]) 损失对我的眼睛来说似乎太低了,但我得到了更好的准确性。我需要几天的时间来尝试这两种配置,每个配置至少 20 个 epoch。我会在这里更新结果。感谢您的帮助。 @SalvadorDali 尝试使用 pos_weight 的标量值的这种方法将所有多数类(0)分类为少数类(1),大大增加了误报。关于为什么会这样的任何线索?提前致谢。【参考方案2】:

正如@Salvador Dali 所说,最好的来源是源代码 https://github.com/tensorflow/tensorflow/blob/5b10b3474bea72e29875264bb34be476e187039c/tensorflow/python/ops/nn_impl.py#L183

我们有

log_weight = 1 + (pos_weight - 1) * targets

所以它只适用于targets==1

如果targets==0 那么log_weight = 1

如果targets==1 那么log_weight = pos_weight

因此,如果我们有正负比率 x/y,我们需要 pos_weight 为 y/x,这样两个类别的总贡献相同

请注意,目标张量中的每个标量对应于每个类别,因此 pos_weight 的每个成员也对应于每个类别(不是一个类别的正概率或负概率)。

【讨论】:

以上是关于如何正确计算 tf.nn.weighted_cross_entropy_with_logits pos_weight 变量的主要内容,如果未能解决你的问题,请参考以下文章

在图像分类中如何计算正确的标签? [关闭]

如何根据按钮框架正确计算cornerRadius

UITableview 部分 - 如何正确重新计算高度?

计算后如何正确转换小数

如何正确计算 Fisher 变换指标

如何获得对云计算的正确控制