Tensorflow CTC 损失:ctc_merge_repeated 参数

Posted

技术标签:

【中文标题】Tensorflow CTC 损失:ctc_merge_repeated 参数【英文标题】:Tensorflow CTC loss: ctc_merge_repeated parameter 【发布时间】:2018-01-16 00:29:24 【问题描述】:

我正在使用 Tensorflow 1.0 及其 CTC 损失 [1]。 训练时,我有时会收到“找不到有效路径”。警告(这会损害学习)。 不是因为其他 Tensorflow 用户有时报告的高学习率。

经过一番分析,我找到了导致此警告的模式:

将输入序列输入到长度为 seqLen 的 ctc_loss 使用 labelLen 个字符输入标签 标签中有 numRepeatedChars 个重复的字符,我将“ab”计为 0,“aa”计为 1,“aaa”计为 2,依此类推 警告发生,当:seqLen - labelLen numRepeatedChars

三个例子:

Ex.1: label="abb", len(label)=3, len(inputSequence)=3 => (3-3=0) 警告 Ex.2: label="abb", len(label)=3, len(inputSequence)=4 => (4-3=1) 没有警告 Ex.3: label="bbb", len(label)=3, len(inputSequence)=4 => (4-3=1) 警告

当我现在设置 ctc_loss 参数 ctc_merge_repeated=False 时,警告就会消失。

三个问题:

Q1:为什么出现重复字符时会出现警告?我想,只要输入序列不短于目标标注,就没有问题。而当重复的字符被合并到标签中时,它会变得更短,因此输入序列不短的条件仍然成立。 Q2:为什么 ctc_loss 在其默认设置中会产生此警告?重复字符在使用 CTC 的领域中很常见,例如手写文本识别 (HTR) Q3:做HTR时应该使用哪些设置?当然标签可以有重复的字符。因此 ctc_merge_repeated=False 是有意义的。有什么建议?

重现警告的 Python 程序:

import tensorflow as tf
import numpy as np

def createGraph():
    tinputs=tf.placeholder(tf.float32, [100, 1, 65]) # max 100 time steps, 1 batch element, 64+1 classes
    tlabels=tf.SparseTensor(tf.placeholder(tf.int64, shape=[None,2]) , tf.placeholder(tf.int32,[None]), tf.placeholder(tf.int64,[2])) # labels
    tseqLen=tf.placeholder(tf.int32, [None]) # list of sequence length in batch
    tloss=tf.reduce_mean(tf.nn.ctc_loss(labels=tlabels, inputs=tinputs, sequence_length=tseqLen, ctc_merge_repeated=True)) # ctc loss
    return (tinputs, tlabels, tseqLen, tloss)

def getNextBatch(nc): # next batch with given number of chars in label
    indices=[[0,i] for i in range(nc)]
    values=[i%65 for i in range(nc)]
    values[0]=0
    values[1]=0 # TODO: (un)comment this to trigger warning
    shape=[1, nc]
    labels=tf.SparseTensorValue(indices, values, shape)
    seqLen=[nc]
    inputs=np.random.rand(100, 1, 65)
    return (labels, inputs, seqLen) 


(tinputs, tlabels, tseqLen, tloss)=createGraph()

sess=tf.Session()
sess.run(tf.global_variables_initializer())

nc=3 # number of chars in label
print('next batch with 1 element has label len='+str(nc))
(labels, inputs, seqLen)=getNextBatch(nc)
res=sess.run([tloss],  tlabels: labels, tinputs:inputs, tseqLen:seqLen  )

这是警告来自的 C++ Tensorflow 代码 [2]:

// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) 
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;

[1]https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/nn/ctc_loss

[2]https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/ctc/ctc_loss_calculator.cc

【问题讨论】:

【参考方案1】:

好的,明白了,这不是错误,这正是 CTC 的工作原理:让我们举一个发生警告的示例:输入序列的长度为 2,标签为“aa”(长度也为 2)。

现在产生“aa”的最短路径是 a->blank->a(长度为 3)。 但是对于标签“ab”,最短路径是 a->b(长度为 2)。 这说明了为什么对于像“aa”这样的重复标签,输入序列必须更长。它只是通过插入空白在 CTC 中对重复标签进行编码的方式。

因此,在固定输入大小时,标签重复会减少允许标签的最大长度。

【讨论】:

以上是关于Tensorflow CTC 损失:ctc_merge_repeated 参数的主要内容,如果未能解决你的问题,请参考以下文章

利用CRNN来识别图片中的文字(二)tensorflow中ctc有关函数详解

RNN+CTC 模型似乎没有正确获取数据维度

(CRNN OCR) 训练时出错!无效参数:sequence_length(0) <= 18 节点 ctc/CTCLoss

使用keras框架cnn+ctc_loss识别不定长字符图片操作

tensorflow 打印的损失是批量/样本损失还是运行平均损失?

Tensorflow:损失变成'NaN'