损失函数Center Loss 代码解析

Posted Image Process

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了损失函数Center Loss 代码解析相关的知识,希望对你有一定的参考价值。

center loss来自ECCV2016的一篇论文:A Discriminative Feature Learning Approach for Deep Face Recognition。 
论文链接:http://ydwen.github.io/papers/WenECCV16.pdf 
代码链接:https://github.com/davidsandberg/facenet

理论解析请参看 https://blog.csdn.net/u014380165/article/details/76946339

下面给出centerloss的计算公式以及更新公式

 

 

下面的代码是facenet作者利用tensorflow实现的centerloss代码

def center_loss(features, label, alfa, nrof_classes):
    """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
       (http://ydwen.github.io/papers/WenECCV16.pdf)
       https://blog.csdn.net/u014380165/article/details/76946339
    """
    nrof_features = features.get_shape()[1]
  #训练过程中,需要保存当前所有类中心的全连接预测特征centers, 每个batch的计算都要先读取已经保存的centers centers
= tf.get_variable(\'centers\', [nrof_classes, nrof_features], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False) label = tf.reshape(label, [-1]) centers_batch = tf.gather(centers, label)#获取当前batch对应的类中心特征 diff = (1 - alfa) * (centers_batch - features)#计算当前的类中心与特征的差异,用于Cj的的梯度更新,这里facenet的作者做了一个 1-alfa操作,比较奇怪,和原论文不同 centers = tf.scatter_sub(centers, label, diff)#更新梯度Cj,对于上图中步骤6,tensorflow会将该变量centers保留下来,用于计算下一个batch的centerloss loss = tf.reduce_mean(tf.square(features - centers_batch))#计算当前的centerloss 对应于Lc return loss, centers

 

以上是关于损失函数Center Loss 代码解析的主要内容,如果未能解决你的问题,请参考以下文章

sklearn基于make_scorer函数为Logistic模型构建自定义损失函数+代码实战(二元交叉熵损失 binary cross-entropy loss)

Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码

yolov5 loss函数理解

4 关于word2vec的skip-gram模型使用负例采样nce_loss损失函数的源码剖析

如何在tensorflow,DNNLinearCombinedClassifier中使用不同的损失函数

Center Loss - A Discriminative Feature Learning Approach for Deep Face Recognition