loss函数之triplet loss
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了loss函数之triplet loss相关的知识,希望对你有一定的参考价值。
参考技术A 不同于交叉熵损失仅仅考虑样本与类别标签之间误差,triplet loss关注样本与其他样本之间距离。来自论文 Learning local feature descriptors with triplets and shallow convolutional neural networks对于包含 个样本的batch数据 。 第 个样本对应的 ,如下:
其中, , , ,分别代表锚点,正例(与锚点同类)和负例(与锚点不同类)。距离函数 , 用于度量锚点与正例负例之间的距离。 是人为设置的常数。最小化损失函数,使得锚点与正例的距离越小,与负例的距离越大。
由以上公式可知,
(1) 当 ,即 , 该样本对应的 为0。
此时,锚点和负例的距离大于锚点和正例的距离,并且差值大于 。 对于这样的锚点被认为是易分类样本,直接忽略其带来的误差,从而加速计算。
(2) 当 , 该样本对应的 为 , 分为两种情况:
pytorch中通过 torch.nn.TripletMarginLoss 类实现,也可以直接调用 F.triplet_margin_loss 函数。 size_average 与 reduce 已经弃用。reduction有三种取值 mean , sum , none ,对应不同的返回 。 默认为 mean ,对应于一般情况下整体 的计算。
该类默认使用如下距离函数, 默认为2,对应欧式距离。
pytorch也有计算该距离的函数 torch.nn.PairwiseDistance
例子:
结果:
该loss函数与 TripletMarginLoss功能基本一致,只不过可以定制化的传入不同的距离函数。当传入的距离函数是 torch.nn.PairwiseDistance 时,两者完全一致
例子:
结果和TripletMarginLoss一致:
使用自定义的距离函数:
结果:
以上是关于loss函数之triplet loss的主要内容,如果未能解决你的问题,请参考以下文章
基于Triplet loss函数训练人脸识别深度网络(Open Face)