(CRNN OCR) 训练时出错!无效参数:sequence_length(0) <= 18 节点 ctc/CTCLoss

Posted

技术标签:

【中文标题】(CRNN OCR) 训练时出错!无效参数:sequence_length(0) <= 18 节点 ctc/CTCLoss【英文标题】:(CRNN OCR) Error while training! Invalid Argument: sequence_length(0) <= 18 node ctc/CTCLoss 【发布时间】:2020-07-31 18:54:01 【问题描述】:

我在 OCR 上使用 CRNN(CNN + RNN + CTC 损失)作为我的模型。我正在使用 Tensorflow Keras

这是我的代码 [来自 CTC Loss]:

labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')


def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args

    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([outputs, labels, input_length, label_length])

#model to be used at training time
model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)

model.compile(loss='ctc': lambda y_true, y_pred: y_pred, optimizer = 'adam')

filepath="best_model.hdf5"
checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')
callbacks_list = [checkpoint]

training_img = np.array(training_img)
train_input_length = np.array(train_input_length)
train_label_length = np.array(train_label_length)

valid_img = np.array(valid_img)
valid_input_length = np.array(valid_input_length)
valid_label_length = np.array(valid_label_length)

训练时出现错误:

batch_size = 256
epochs = 10
model.fit(x=[training_img, train_padded_txt, train_input_length, train_label_length], y=np.zeros(len(training_img)), 
          batch_size=batch_size, epochs = epochs, 
          validation_data = ([valid_img, valid_padded_txt, valid_input_length, valid_label_length], [np.zeros(len(valid_img))]), 
          verbose = 1, callbacks = callbacks_list)

错误结果

Train on 448 samples, validate on 49 samples
Epoch 1/10
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-15-1322212af569> in <module>()
      4           batch_size=batch_size, epochs = epochs,
      5           validation_data = ([valid_img, valid_padded_txt, valid_input_length, valid_label_length], [np.zeros(len(valid_img))]),
----> 6           verbose = 1, callbacks = callbacks_list)

7 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  sequence_length(0) <= 18
     [[node ctc/CTCLoss (defined at /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_12073]

Function call stack:
keras_scratch_graph

我的 CRNN 架构受到 VGG-16 的启发,我使用了 13 个卷积层和 3 个双向 LSTM 层。我正在使用 CTC Loss,然后出现错误。 我的数据是 1000 个文本图像,包含 4-8 个单词(700 个用于训练和验证,300 个用于测试)

如果您想查看我的代码:这是我使用 google colab 的代码。 https://colab.research.google.com/drive/1nMRNUsLDNrpgeTxPFQ4mhobnFdpbmwUx

【问题讨论】:

我不是 OCR 方面的专家,但函数调用堆栈错误有点熟悉。每当我的 GPU 用完它的 RAM 时,我都会遇到它。尝试将 GPU 更改为 TPU,看看它是否可以解决您的问题.... 【参考方案1】:

我修复了这个错误。就是因为这个!

之前:

  # split the 700 data into validation and training dataset as 10% and 90% respectively
        if i%10 == 0:     
            valid_orig_txt.append(txt)   
            valid_label_length.append(len(txt))
            valid_input_length.append(31)
            valid_img.append(img)
            valid_txt.append(encode_to_labels(txt))
        else:
            orig_txt.append(txt)   
            train_label_length.append(len(txt))
            train_input_length.append(31)
            training_img.append(img)
            training_txt.append(encode_to_labels(txt)) 

之后:

  # split the 700 data into validation and training dataset as 10% and 90% respectively
        if i%10 == 0:     
            valid_orig_txt.append(txt)   
            valid_label_length.append(len(txt))
            valid_input_length.append(18)
            valid_img.append(img)
            valid_txt.append(encode_to_labels(txt))
        else:
            orig_txt.append(txt)   
            train_label_length.append(len(txt))
            train_input_length.append(18)
            training_img.append(img)
            training_txt.append(encode_to_labels(txt)) 

【讨论】:

以上是关于(CRNN OCR) 训练时出错!无效参数:sequence_length(0) <= 18 节点 ctc/CTCLoss的主要内容,如果未能解决你的问题,请参考以下文章

[深度学习][OCR][原创]CRNN_Chinese_Characters_Rec训练360w数据集提示keyerror错误解决方法

百度飞桨(PaddlePaddle)

OCR-CRNN (CNN+CTC)文字识别,实践上手

使用谷歌colab训练crnn模型

利用CRNN来识别图片中的文字(一)数据预处理

在 pycharm 上加载经过训练的 Tensorflow 保存模型时出错。 ValueError:int() 的无效文字,基数为 10:'class_name'