loss函数之triplet loss

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了loss函数之triplet loss相关的知识,希望对你有一定的参考价值。

参考技术A 不同于交叉熵损失仅仅考虑样本与类别标签之间误差,triplet loss关注样本与其他样本之间距离。来自论文 Learning local feature descriptors with triplets and shallow convolutional neural networks

对于包含 个样本的batch数据 。 第 个样本对应的 ,如下:

其中, , , ,分别代表锚点,正例(与锚点同类)和负例(与锚点不同类)。距离函数 , 用于度量锚点与正例负例之间的距离。 是人为设置的常数。最小化损失函数,使得锚点与正例的距离越小,与负例的距离越大。

由以上公式可知,

(1) 当 ,即 , 该样本对应的 为0。

此时,锚点和负例的距离大于锚点和正例的距离,并且差值大于 。 对于这样的锚点被认为是易分类样本,直接忽略其带来的误差,从而加速计算。

(2) 当 , 该样本对应的 为 , 分为两种情况:

pytorch中通过 torch.nn.TripletMarginLoss 类实现,也可以直接调用 F.triplet_margin_loss 函数。 size_average 与 reduce 已经弃用。reduction有三种取值 mean , sum , none ,对应不同的返回 。 默认为 mean ,对应于一般情况下整体 的计算。

该类默认使用如下距离函数, 默认为2,对应欧式距离。

pytorch也有计算该距离的函数 torch.nn.PairwiseDistance

例子:

结果:

该loss函数与 TripletMarginLoss功能基本一致,只不过可以定制化的传入不同的距离函数。当传入的距离函数是 torch.nn.PairwiseDistance 时,两者完全一致

例子:

结果和TripletMarginLoss一致:

使用自定义的距离函数:

结果:

以上是关于loss函数之triplet loss的主要内容,如果未能解决你的问题,请参考以下文章

基于Triplet loss函数训练人脸识别深度网络(Open Face)

Triplet-Loss原理及其实现应用

Triplet-Loss原理及其实现应用

load_model 如何导入自定义的loss 函数

深度学习方法(十九):一文理解Contrastive Loss,Triplet Loss,Focal Loss

triplet loss