说话人识别损失函数的PyTorch实现与代码解读
Posted DEDSEC_Roger
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了说话人识别损失函数的PyTorch实现与代码解读相关的知识,希望对你有一定的参考价值。
概述
- 说话人识别中的损失函数分为基于多类别分类的损失函数,和端到端的损失函数(也叫基于度量学习的损失函数),关于这些损失函数的理论部分,可参考说话人识别中的损失函数
- 本文主要关注这些损失函数的实现,此外,文章说话人识别中的损失函数中,没有详细介绍基于多类别分类的损失函数,因此本文会顺便补足这一点
- 本文持续更新
Softmax Loss
- 先看Softmax Loss,完整的叫法是Cross-entropy Loss with Softmax,主要由三部分组成
-
Fully Connected:将当前样本的嵌入码(embedding),变换成长度为类别数的向量(通常称为Logit),公式如下
y = W x + b y=Wx+b y=Wx+b
其中- x是特征向量,长度为 e m b e d - d i m embed\\text-dim embed-dim
- W是权重矩阵,维度为 [ n - c l a s s e s , e m b e d - d i m ] [n\\text-classes,embed\\text-dim] [n-classes,embed-dim], n - c l a s s e s n\\text-classes n-classes为类别数
- b是偏置向量,长度为 n - c l a s s e s n\\text-classes n-classes
- Logit中的每一个值,对应W的每一行与x逐项相乘再相加,然后与b中的对应项再相加
-
Softmax:将Logit变换成多类别概率分布Probability,不改变向量长度,公式如下(取 N = n - c l a s s e s − 1 N=n\\text-classes-1 N=n-classes−1)
y i = e x i ∑ i = 0 N e x i y_i=\\frace^x_i\\sum_i=0^Ne^x_i yi=∑i=0Nexiexi
- 本质上是max函数的软化版本,将不可导的max函数变得可导
- 因此需要像max函数那样,具有最大值主导的特点,上图中
s o f t m a x ( [ 3 , 1 , − 3 ] ) = [ 0.88 , 0.12 , 0 ] softmax([3,1,-3])=[0.88,0.12,0] softmax([3,1,−3])=[0.88,0.12,0] - 又因为输出是多类别概率分布,因此Probability的每一项相加等于1
∑ i = 0 N y i = 1 \\sum_i=0^Ny_i=1 i=0∑Nyi=1 - 但是当Logit的值都比较小时,比如:
[
0
,
1
]
[0,1]
[0,1],最大值主导的效果不明显
s o f t m a x ( [ 0.1 , 0.3 , 0.5 , 0.7 , 0.9 ] ) = [ 0.1289 , 0.1574 , 0.1922 , 0.2348 , 0.2868 ] softmax([0.1,0.3,0.5,0.7,0.9])=[0.1289, 0.1574, 0.1922, 0.2348, 0.2868] softmax([0.1,0.3,0.5,0.7,0.9])=[0.1289,0.1574,0.1922,0.2348,0.2868]
-
Cross-entropy(交叉熵):将Ground Truth(基本事实)的One-hot Vector(记为 P P P)与Probability(记为 Q Q Q)计算相似度,输出是标量。交叉熵的值越小,Probability与One-hot Vector越相似,公式如下
L C E ( P , Q ) = − ∑ i = 0 N p i log ( q i ) L_CE(P,Q)=-\\sum_i=0^N p_i \\log(q_i) LCE(P,Q)=−i=0∑Npilog(qi)- One-hot Vector的长度与Probability一致,即等于类别数 N N N,形式为 [ 0 , 0 , . . . , 1 , . . . , 0 ] [0,0,...,1,...,0] [0,0,...,1,...,0],即GT是哪个类,哪个类对应的下标就为1
- 设One-hot Vector值为1的下标为
j
j
j,上式可简化为
L S o f t m a x ( P , Q ) = − log ( q j ) = − log ( e x j ∑ i = 0 N e x i ) L_Softmax(P,Q)=-\\log(q_j)=-\\log(\\frace^x_j\\sum_i=0^Ne^x_i) LSoftmax(P,Q)=−log(qj)=−log(∑i=0Nexiexj)
-
- 在上述的过程中,如果用tensor.scatter_来实现One-hot Vector是比较难懂的,完整PyTorch代码如下
import torch import torch.nn.functional as F import torch.nn as nn embed_dim = 5 num_class = 10 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") x = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]) x.unsqueeze_(0) # 模拟batch-size,就地在dim = 0插入维度,此时x的维度为[1,5] x = x.expand(2, embed_dim) # 直接堆叠x,使batch-size = 2,此时x的维度为[2,5] x = x.float().to(device) # label是长度为batch-size的向量,每个值是GT的下标,维度为[2] label = torch.tensor([0, 5]) label = label.long().to(device) weight = nn.Parameter(torch.FloatTensor(num_class, embed_dim)).to(device) nn.init.xavier_uniform_(weight) # 初始化权重矩阵 logit = F.linear(x, weight) # 取消偏置向量 probability = F.softmax(logit, dim=1) # 维度为[2,10] # one_hot的数据类型与设备要和x相同,维度和Probability相同[2,10] one_hot = x.new_zeros(probability.size()) # 根据label,就地得到one_hot,步骤如下 # scatter_函数:Tensor.scatter_(dim, index, src, reduce=None) # 先把label的维度变为[2,1],然后根据label的dim = 1(参数中的src)上的值 # 作为one_hot的dim = 1(参数中的dim)上的下标,并将下标对应的值设置为1 # 由于label的dim = 1上的值只有一个,所以是One-hot,如果label维度为[2,2],则为Two-hot # 如果label维度为[2,k],则为K-hot one_hot.scatter_(1, label.view(-1, 1).long(), 1) # 等价于 # one_hot = F.one_hot(label, num_class).float().to(device) # 但是F.one_hot只能构造One-hot,Tensor.scatter_可以构造K-hot # 对batch中每个样本计算loss,并求均值 loss = 0 for P, Q in zip(one_hot, probability): loss += torch.log((P * Q).sum()) loss /= -one_hot.size()[0] # 等价于 # loss = F.cross_entropy(logit, label)
- 上述PyTorch代码要看懂,是之后魔改Softmax Loss的基础
AAM-Softmax(ArcFace)
-
AAM-Softmax(Additive Angular Margin Loss,也叫ArcFace)出自人脸识别,是说话人识别挑战VoxSRC近年冠军方案的基础损失函数,是基于Softmax Loss进行改进而来的。步骤如下
-
取消偏置向量,根据上文,Logit中的每一个值,对应W的每一行 w i w_i wi与x逐项相乘再相加,即 y i = w i x y_i=w_ix yi=wix
-
把 w i w_i wi和 x x x都单位化
w i ′ = w i ∣ ∣ w i ∣ ∣ , x ′ = x ∣ ∣ x ∣ ∣ w'_i=\\fracw_i||w_i||,x'=\\fracx||x|| wi′=∣∣wi∣∣wi,x′=∣∣x∣∣x -
计算Logit,此时Logit中的每一个值如下,即 w
以上是关于说话人识别损失函数的PyTorch实现与代码解读的主要内容,如果未能解决你的问题,请参考以下文章
深度学习100例 | 第4例:水果识别 - PyTorch实现
深度学习100例 | 第4例:水果识别 - PyTorch实现
深度学习100例 | 第3天:交通标志识别 - PyTorch实现
-