在tensorflow中使用glstm(Group LSTM) cell构建双向rnn

Posted

技术标签:

【中文标题】在tensorflow中使用glstm(Group LSTM) cell构建双向rnn【英文标题】:Use glstm(Group LSTM) cell to build bidirectional rnn in tensorflow 【发布时间】:2017-12-27 08:43:18 【问题描述】:

我正在使用一个cnn + lstm + ctc网络(基于https://arxiv.org/pdf/1507.05717.pdf)做一个中文场景文本识别。对于大量的类(3500+),网络很难训练。听说使用 Group LSTM (https://arxiv.org/abs/1703.10722, O. Kuchaiev and B. Ginsburg "Factorization Tricks for LSTM Networks", ICLR 2017 Workshop.) 可以减少参数数量并加速训练,所以我尝试使用它在我的代码中。

我使用两层双向 lstm。这是使用 tf.contrib.rnn.LSTMCell 的原始代码

rnn_outputs, _, _ = 
tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
[tf.contrib.rnn.LSTMCell(num_units=self.num_hidden, state_is_tuple=True) for _ in range(self.num_layers)],
[tf.contrib.rnn.LSTMCell(num_units=self.num_hidden, state_is_tuple=True) for _ in range(self.num_layers)], 
self.rnn_inputs, dtype=tf.float32, sequence_length=self.rnn_seq_len, scope='BDDLSTM')

训练很慢。 100 小时后,测试集上的预测准确率仍为 39%。

现在我想使用 tf.contrib.rnn.GLSTMCell。当我用这个 GLSTMCell 替换 LSTMCell 时,就像

rnn_outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
[tf.contrib.rnn.GLSTMCell(num_units=self.num_hidden, num_proj=self.num_proj, number_of_groups=4) for _ in range(self.num_layers)],
[tf.contrib.rnn.GLSTMCell(num_units=self.num_hidden, num_proj=self.num_proj, number_of_groups=4) for _ in range(self.num_layers)],
self.rnn_inputs, dtype=tf.float32, sequence_length=self.rnn_seq_len, scope='BDDLSTM')

我收到以下错误

/home/frisasz/miniconda2/envs/dl/bin/python "/media/frisasz/DATA/FSZ_Work/deep learning/IDOCR_/work/train.py"
Traceback (most recent call last):
  File "/media/frisasz/DATA/FSZ_Work/deep learning/IDOCR_/work/train.py", line 171, in <module>
    train(train_dir='/media/frisasz/Windows/40T/', val_dir='../../0000/40V/')
  File "/media/frisasz/DATA/FSZ_Work/deep learning/IDOCR_/work/train.py", line 41, in train
    FLAGS.momentum)
  File "/media/frisasz/DATA/FSZ_Work/deep learning/IDOCR_/work/model.py", line 61, in __init__
    self.logits = self.rnn_net()
  File "/media/frisasz/DATA/FSZ_Work/deep learning/IDOCR_/work/model.py", line 278, in rnn_net
    self.rnn_inputs, dtype=tf.float32, sequence_length=self.rnn_seq_len, scope='BDDLSTM')
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/rnn.py", line 220, in stack_bidirectional_dynamic_rnn
    dtype=dtype)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 375, in bidirectional_dynamic_rnn
    time_major=time_major, scope=fw_scope)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 574, in dynamic_rnn
    dtype=dtype)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 737, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2770, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2599, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2549, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 720, in _time_step
    skip_conditionals=True)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 206, in _rnn_step
    new_output, new_state = call_cell()
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py", line 708, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 441, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/rnn_cell.py", line 2054, in call
    R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False)
  File "/home/frisasz/miniconda2/envs/dl/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1005, in _linear
    "but saw %s" % (shape, shape[1]))
ValueError: linear expects shape[1] to be provided for shape (?, ?), but saw ?

Process finished with exit code 1

我不确定 GLSTMCell 是否可以简单地替换 tf.contrib.rnn.stack_bidirectional_dynamic_rnn() 中的 LSTMCell(或其他有助于构建 rnn 的函数)。我没有找到任何使用 GLSTMCell 的例子。有人知道用 GLSTMCell 构建双向 rnn 的正确方法吗?

【问题讨论】:

【参考方案1】:

我在尝试使用 bidirectional_dynamic_rnn 构建双向 GLSTM 时遇到了完全相同的错误。

在我的例子中,问题来自于 GLSTM 只能在以静态方式定义时使用:当计算图形时,你不能有未定义的形状参数(例如 batch_size)。

因此,尝试在图中定义最终会在 GLSTM 单元格中的某个点结束的所有形状,它应该可以正常工作。

【讨论】:

谢谢,我决定不使用 glstm。我用另一种方式加快了训练速度。

以上是关于在tensorflow中使用glstm(Group LSTM) cell构建双向rnn的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow-tf.group

将时间序列元素的Tensorflow数据集转换为窗口序列的数据集

如何让 Tensorflow Profiler 在 Tensorflow 2.5 中使用“tensorflow-macos”和“tensorflow-metal”工作

TensorFlow Hub 模块可以在 TensorFlow 2.0 中使用吗?

使用 GPU 无法在 tensorflow 教程中运行词嵌入示例

如何在 pytorch 和 tensorflow 中使用张量核心?