RNN+CTC model seems not to be getting the data dimension correctly

Posted: 2021-04-01 17:18:39

Question:

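I am training a simple RNN model (GRU) with a CTC loss function. The code and model summaries are below. I keep getting the error shown further down. It looks as if a data dimension somewhere in the model, probably the input data length (the length in [batch_size, length, mfcc_feature]), is being reduced by 2. Where did I go wrong?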

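import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import GRU, Activation, Lambda
from tensorflow.keras.models import Model

def data_generator(batch_size, wav_files, trn_files, numcep, pinyin_dict):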

    for i in range(len(wav_files)//batch_size):
        print("\n##Start Batch: ", i)
        mfcc_datasets = []
        mfcc_form_orig_len_datasets = []
        pinyin_datasets = []
        pinyin_code_orig_len_datasets = []
        begin = i * batch_size
        end = begin + batch_size
        print("begin: ", begin, "end: ", end)
        dataset_indices = list(range(begin, end))
        print("dataset_indices: ", dataset_indices)
        wav_files_subset = [wav_files[index] for index in dataset_indices]
        trn_files_subset = [trn_files[index] for index in dataset_indices]

        train_wav_max_len_batch = get_wav_max_len(wav_files_subset, numcep)
        train_pinyin_max_len_batch = get_pinyin_max_len(trn_files_subset, pinyin_dict)

        for index in dataset_indices:
        
            # transform wav to mfcc
            mfcc_form = wav_to_mfcc(wav_files[index], numcep)
            mfcc_form_expanded_padded, mfcc_form_orig_len = expand_pad_mfcc(mfcc_form, train_wav_max_len_batch)

            mfcc_datasets.append(mfcc_form_expanded_padded)
            mfcc_form_orig_len_datasets.append(mfcc_form_orig_len)

            
            # transform trn to pinyin code
            code = trn_pinyin_to_code(trn_files[index], pinyin_dict)
            pinyin_code_expanded, pinyin_code_orig_len = expand_trn(code, train_pinyin_max_len_batch)

            pinyin_datasets.append(pinyin_code_expanded)
            pinyin_code_orig_len_datasets.append(pinyin_code_orig_len)

        
        mfcc_datasets = np.array(mfcc_datasets)
        mfcc_form_orig_len_datasets = np.array(mfcc_form_orig_len_datasets)
        pinyin_datasets = np.array(pinyin_datasets)
        pinyin_code_orig_len_datasets = np.array(pinyin_code_orig_len_datasets)
    
        inputs = {'Inputs': mfcc_datasets, # size = (batch_size, length, num of features, channel)
                  'CTC_labels': pinyin_datasets, # size = (batch_size, length)
                  'CTC_input_length': mfcc_form_orig_len_datasets,
                  'CTC_label_length': pinyin_code_orig_len_datasets,
                  }

        outputs = {'ctc': np.zeros(mfcc_datasets.shape[0],)}
        
        print("mfcc_datasets.shape: ", mfcc_datasets.shape)
        print("mfcc_form_orig_len_datasets: ", mfcc_form_orig_len_datasets)
        print("pinyin_datasets.shape: ", pinyin_datasets.shape)
        print("pinyin_code_orig_len_datasets: ", pinyin_code_orig_len_datasets)
        print("outputs.shape: ", np.zeros(mfcc_datasets.shape[0],).shape)
        print("##End Batch: ", i)
        
        yield inputs, outputs

def ctc_lambda_func(args):
    
    y_pred, labels, input_length, label_length = args

    y_pred = y_pred[:, 2:, :]

    return tf.keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)

def ctc_model(inputs, y_pred):
    labels = tf.keras.Input(name='CTC_labels', shape=[None], dtype='float32')
    input_length = tf.keras.Input(name='CTC_input_length', shape=[1], dtype='int64')
    label_length = tf.keras.Input(name='CTC_label_length', shape=[1], dtype='int64')
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
    
    ctc_model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
    print(ctc_model.summary())
    
    return ctc_model

def simple_rnn_model(input_feat_dim, output_feat_dim):

    inputs = tf.keras.Input(name='Inputs', shape=(None, input_feat_dim))
    x = GRU(name='GRU_1', units=output_feat_dim, return_sequences=True, kernel_initializer='he_normal')(inputs)
    y_pred = Activation('softmax', name='Softmax')(x)
    
    model = Model(inputs=inputs, outputs=y_pred)
    print(model.summary())
    
    ctc_model_0 = ctc_model(inputs, y_pred)
    
    return ctc_model_0

model_0 = simple_rnn_model(input_feat_dim=MFCC_FEATURES, output_feat_dim=pinyin_dict.shape[0])

Here are the model summaries:

Model: "model_8"

Layer (type)                 Output Shape              Param #
Inputs (InputLayer)          [(None, None, 13)]        0
GRU_1 (GRU)                  (None, None, 29)          3828
Softmax (Activation)         (None, None, 29)          0

Total params: 3,828
Trainable params: 3,828
Non-trainable params: 0

None

Model: "model_9"

Layer (type)                    Output Shape         Param #     Connected to
Inputs (InputLayer)             [(None, None, 13)]   0
GRU_1 (GRU)                     (None, None, 29)     3828        Inputs[0][0]
Softmax (Activation)            (None, None, 29)     0           GRU_1[0][0]
CTC_labels (InputLayer)         [(None, None)]       0
CTC_input_length (InputLayer)   [(None, 1)]          0
CTC_label_length (InputLayer)   [(None, 1)]          0
ctc (Lambda)                    (None, 1)            0           Softmax[0][0]
                                                                 CTC_labels[0][0]
                                                                 CTC_input_length[0][0]
                                                                 CTC_label_length[0][0]

Total params: 3,828
Trainable params: 3,828
Non-trainable params: 0

Here is the error message, together with the printouts I added in between for analysis:

#Training Epoch:...  0

##Start Batch:  0
begin:  0 end:  2
dataset_indices:  [0, 1]
mfcc_datasets.shape:  (2, 883, 13)
mfcc_form_orig_len_datasets:  [779 883]
pinyin_datasets.shape:  (2, 34)
pinyin_code_orig_len_datasets:  [31 34]
outputs.shape:  (2,)
##End Batch:  0

##Start Batch:  1
begin:  2 end:  4
dataset_indices:  [2, 3]
mfcc_datasets.shape:  (2, 819, 13)
mfcc_form_orig_len_datasets:  [819 794]
pinyin_datasets.shape:  (2, 33)
pinyin_code_orig_len_datasets:  [33 32]
outputs.shape:  (2,)
##End Batch:  1
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-56-f01d15ca6481> in <module>
----> 1 hist = train_model(model_0)

<ipython-input-55-72d8aa1713c5> in train_model(model)
     20         print('#Training Epoch:... ', epoch)
     21         batch = data_generator(BATCH_SIZE, train_wav_files[0:8], train_trn_files[0:8], NUMCEP, pinyin_dict)
---> 22         hist = current_model.fit(batch, steps_per_epoch=BATCH_NUM, epochs=1, verbose=1)
     23 
     24     return hist

~\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

~\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    886         # Lifting succeeded, so variables are initialized and we can run the
    887         # stateless function.
--> 888         return self._stateless_fn(*args, **kwds)
    889     else:
    890       _, _, _, filtered_flat_args = \

~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2940       (graph_function,
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 2942     return graph_function._call_flat(
   2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 

~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1916         and executing_eagerly):
   1917       # No tape is watching; skip to running the function.
-> 1918       return self._build_call_outputs(self._inference_function.call(
   1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(

~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    553       with _InterpolateFunctionError(self):
    554         if cancellation_manager is None:
--> 555           outputs = execute.execute(
    556               str(self.signature.name),
    557               num_outputs=self._num_outputs,

~\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError:  sequence_length(1) <= 881
     [[node model_9/ctc/CTCLoss (defined at <ipython-input-34-5693b53d741a>:8) ]] [Op:__inference_train_function_11579]

Function call stack:
train_function

Comments:

I also wanted to add the training code here, but I can't seem to edit my own post (sorry, first time posting here):

for epoch in range(epochs):
    print('#Training Epoch:... ', epoch)
    batch = data_generator(BATCH_SIZE, train_wav_files[0:8], train_trn_files[0:8], NUMCEP, pinyin_dict)
    hist = current_model.fit(batch, steps_per_epoch=BATCH_NUM, epochs=1, verbose=1)

Answer 1:

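OK, I found the cause. It is because I had this line

y_pred = y_pred[:, 2:, :]

in the ctc_lambda_func function. That slice drops the first two time steps, so y_pred has only 881 steps while CTC_input_length still reports 883 for the longest sample in the batch, which is exactly what the error "sequence_length(1) <= 881" complains about.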
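The length constraint behind this error can be checked before a batch ever reaches the loss. The following is a minimal sketch in plain Python with no TensorFlow dependency; `check_ctc_lengths` and its `dropped_steps` parameter are hypothetical names introduced here for illustration and are not part of the original code or of any library. It assumes the only requirements are that each input length fits within the time steps left after slicing and that each label is no longer than its input.

```python
def check_ctc_lengths(time_steps, input_lengths, label_lengths, dropped_steps=0):
    """Return a list of length problems that would make a CTC loss reject a batch.

    time_steps    -- padded number of frames in the batch, before any slicing
    dropped_steps -- leading time steps sliced off y_pred (2 in the question)
    """
    # Number of time steps the CTC loss actually receives.
    effective_steps = time_steps - dropped_steps
    problems = []
    for i, (n_in, n_lab) in enumerate(zip(input_lengths, label_lengths)):
        # The declared input length must fit inside the prediction tensor.
        if n_in > effective_steps:
            problems.append(
                f"sample {i}: input_length {n_in} > {effective_steps} time steps"
            )
        # A label cannot be longer than the sequence it is aligned to.
        if n_lab > n_in:
            problems.append(
                f"sample {i}: label_length {n_lab} > input_length {n_in}"
            )
    return problems


# Batch 0 from the question: padded to 883 frames, input lengths [779, 883].
print(check_ctc_lengths(883, [779, 883], [31, 34], dropped_steps=2))
# Without the slice, the same batch has no length problems.
print(check_ctc_lengths(883, [779, 883], [31, 34], dropped_steps=0))
```

With the slice in place, sample 1 is flagged because 883 exceeds the 881 remaining steps. Two ways out follow from this: delete the slice in ctc_lambda_func, or keep it and feed input lengths that have been reduced by the same 2 steps.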
