tf.keras.losses.categorical_crossentropy 返回错误值

Posted

技术标签:

【中文标题】tf.keras.losses.categorical_crossentropy 返回错误值【英文标题】:tf.keras.losses.categorical_crossentropy returning wrong value 【发布时间】:2020-10-05 08:47:14 【问题描述】:

我有

y_true = 16

y_pred = array([1.1868494e-08, 1.8747659e-09, 1.2777099e-11, 3.6140797e-08,
                6.5852622e-11, 2.2888577e-10, 1.4515833e-09, 2.8392664e-09,
                4.7054605e-10, 9.5605066e-11, 9.3647139e-13, 2.6149302e-10,
                2.5338919e-14, 4.8815413e-10, 3.9381631e-14, 2.1434269e-06,
                9.9999785e-01, 3.0857247e-08, 1.3536775e-09, 4.6811921e-10,
                3.0638234e-10, 2.0818169e-09, 2.9950772e-10, 1.0457132e-10,
                3.2959850e-11, 3.4232595e-10, 5.1689473e-12], dtype=float32)

当我使用tf.keras.losses.categorical_crossentropy(to_categorical(y_true,num_classes=27),y_pred,from_logits=True)

我得到的损失值是2.3575358

但是如果我使用分类交叉熵的公式来获得损失值

-np.sum(to_categorical(gtp_out_true[0],num_classes=27)*np.log(gtp_pred[0]))

根据公式

我得到了2.1457695e-06的值

现在,我的问题是,为什么函数 tf.keras.losses.categorical_crossentropy 给出不同的值。

奇怪的是,我的模型给出了 100% 的准确率,即使损失停留在 2.3575。 下图是训练期间的准确率和损失图。

Tensorflow 使用什么公式计算分类交叉熵?

【问题讨论】:

【参考方案1】:

找到问题所在

我在最后一层使用了 softmax 激活

output = Dense(NUM_CLASSES, activation='softmax')(x)

但我在tf.keras.losses.categorical_crossentropy 中使用了from_logits=True,这导致softmax 再次应用于最后一层(已经是softmax(logits))的输出。所以,我传递给损失函数的output 参数是softmax(softmax(logits))

因此,损失值的异常。

当在最后一层使用softmax作为激活时,我们应该使用from_logits=False

【讨论】:

【参考方案2】:

y_pred 作为概率向量,所以你不应该使用from_logits=True。将其设置为False,您会得到:

>>> print(categorical_crossentropy(to_categorical(16, num_classes = 27),
                                   y_pred, from_logits = False).numpy())
2.264979e-06

我相信它不等于预期的 2.1457695e-06 的原因是因为 y_pred[16] 非常接近 1.0 并且categorical_crossentropy 增加了一些平滑度。

在此处查看有关 logits 的讨论的答案:What is the meaning of the word logits in TensorFlow?

如果每个输入值只能有一个标签,你也可以使用函数的稀疏版本:

print(sparse_categorical_crossentropy(16, y_pred))

【讨论】:

以上是关于tf.keras.losses.categorical_crossentropy 返回错误值的主要内容,如果未能解决你的问题,请参考以下文章