tf.nn.softmax_cross_entropy_with_logits

Posted nowgood

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.nn.softmax_cross_entropy_with_logits相关的知识,希望对你有一定的参考价值。

函数功能: 计算 logits 与 labels 的 softmax 交叉熵.

函数定义

def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
                                      labels=None,
                                      logits=None,
                                      dim=-1, name=None)                              

参数详解:

  1. # pylint: disable=invalid-name 作用是关闭该行代码中告警消息
  2. _sentinel:该参数为内部使用,实际运用中不会使用。目的是占据函数参数的第一个位置,让我们不能通过位置传递labels 和 logits,而是使用关键字传递参数,避免混淆。原理是:实际使用时_sentinel不用传入该参数,所以我们想要通过位置传递 labels 和 logits 时,实际上参数传递情况为传给了 _sentinel, labels。函数内部检查 logits没有值传进来,就会报错。
  3. labels: 实际标签
  4. logits: 对数几率,需要注意的是这个函数内部自动计算 softmax,然后再计算交叉熵代价函数,也就是说 logits 必须是没有经过 tf.nn.softmax 函数处理的数据,否则导致训练结果有问题。

notice!!! Measures the probability error in discrete classification tasks in which the
classes are mutually exclusive (each entry is in exactly one class). For
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.

执行步骤

具体的执行流程大概分为两步:

第一步:
是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个 num_classes 大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

第二步:
softmax 的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:

其中指代实际的标签中第 i 个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)
就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值
显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

tf.nn.sparse_softmax_cross_entropy_with_logits

tf.nn.sparse_softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None)

这个函数与上一个函数十分类似,唯一的区别在于labels. labels的每一行为真实类别的索引

以上是关于tf.nn.softmax_cross_entropy_with_logits的主要内容,如果未能解决你的问题,请参考以下文章