pytorch torch.nn.CTCLoss 参数详解

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch torch.nn.CTCLoss 参数详解相关的知识,希望对你有一定的参考价值。

参考技术A CTC(Connectionist Temporal Classification),CTCLoss设计用于解决神经网络数据的label标签和网络预测数据output不能对齐的情况。比如在端到端的语音识别场景中,解析出的语音频谱数据是tensor变量,并没有标识来分割单词与单词(单字与单字),在用模型预测输出output时候也没有这种分隔符,但是数据的label(如:"它涌动的躯体如同一条鲤鱼"),是分割明显的单字,在端到端的模型中输出的元素也是每个单字的概率,所以这时候就需要CTCLoss来进行辨别。

(1)output: model预测的结果,尺寸(T,N,C)元素代表C个分类中每个元素的概率,预测一个batch中N个句子的可能性
T:句子长度,在语音识别场景中,就是音频输出的长度,是不确定的
N:batch_size
C:分类的数量,比如字典中一共有500个汉字,那么C=500
(2)y_label: 模型的正确预测值,即目标标签, 尺寸(N,S)
batch中目标语音内容的序号表示
如:[['它','涌','动','的','躯','体','如','同','一','条','鲤','鱼'], ['小','红','在','门','前','的','长','椅','上','静','静','地','读','书']-->[[123, 12, 5, 555, 43, 23, 678, 1211, 71, 12, 33, 10], [44, 222, 3, 555, 90, 8, 1, 88, 92, 72, 66, 15, 81, 50]],这里N=2, S=16
S:一个batch中最长句子的长度
(3)output_sizes: output的尺寸,(N)
N个元素都是T或者略小于T的数值
(4)label_sizes: label的尺寸(S)
y_label可以是一维的,也就是说所有的label句子的序号都连在一起,这时候label_sizes的元素代表每一个句子的长度,这时label_sizes的元素之和等于y_label的长度。

以上是关于pytorch torch.nn.CTCLoss 参数详解的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch学习笔记:PyTorch进阶训练技巧

pytorch 中的常用矩阵操作

Pytorch Note1 Pytorch介绍

pytorch_geometric + MinkowskiEngine

1. PyTorch是什么?

1. PyTorch是什么?