序列分类的注意力机制(seq2seq tensorflow r1.1)
Posted
技术标签:
【中文标题】序列分类的注意力机制(seq2seq tensorflow r1.1)【英文标题】:Attention mechanism for sequence classification (seq2seq tensorflow r1.1) 【发布时间】:2017-09-25 04:46:32 【问题描述】:我正在尝试构建一个具有注意力机制的双向 RNN 用于序列分类。我在理解辅助函数时遇到了一些问题。我已经看到用于训练的那个需要解码器输入,但是由于我想要整个序列中的一个标签,我不知道我应该在这里给出什么输入。这是我目前构建的结构:
# Encoder LSTM cells
lstm_fw_cell = rnn.BasicLSTMCell(n_hidden)
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden)
# Bidirectional RNN
outputs, states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell,
lstm_bw_cell, inputs=x,
sequence_length=seq_len, dtype=tf.float32)
# Concatenate forward and backward outputs
encoder_outputs = tf.concat(outputs,2)
# Decoder LSTM cell
decoder_cell = rnn.BasicLSTMCell(n_hidden)
# Attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(n_hidden, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell,
attention_mechanism, attention_size=n_hidden)
name="attention_init")
# Initial attention
attn_zero = attn_cell.zero_state(batch_size=tf.shape(x)[0], dtype=tf.float32)
init_state = attn_zero.clone(cell_state=states[0])
# Helper function
helper = tf.contrib.seq2seq.TrainingHelper(inputs = ???)
# Decoding
my_decoder = tf.contrib.seq2seq.BasicDecoder(cell=attn_cell,
helper=helper,
initial_state=init_state)
decoder_outputs, decoder_states = tf.contrib.seq2seq.dynamic_decode(my_decoder)
我的输入是一个序列 [batch_size,sequence_length,n_features],我的输出是一个包含 N 个可能类 [batch_size,n_classes] 的单个向量。
你知道我在这里遗漏了什么,或者是否可以使用 seq2seq 进行序列分类?
【问题讨论】:
【参考方案1】:根据定义,Seq2Seq 模型不适合这样的任务。顾名思义,它将输入序列(句子中的单词)转换为标签序列(单词的词性)。在您的情况下,您正在寻找每个样本的单个标签,而不是它们的序列。
幸运的是,您已经拥有了所需的一切,因为您只需要编码器(RNN)的输出或状态。
使用它创建分类器的最简单方法是使用 RNN 的最终状态。在此之上添加一个形状为 [n_hidden, n_classes] 的全连接层。在此您可以训练一个 softmax 层和预测最终类别的损失。
原则上,这不包括注意力机制。但是,如果你想包含一个,可以通过一个学习向量对 RNN 的每个输出进行加权,然后求和来完成。但是,这并不能保证改善结果。如果我没记错的话,https://arxiv.org/pdf/1606.02601.pdf 实现了这种类型的注意机制以供进一步参考。
【讨论】:
我不同意 seq2seq 不适合分类。这里用于分类任务:andriymulyar.com/blog/bert-document-classification以上是关于序列分类的注意力机制(seq2seq tensorflow r1.1)的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch-16 seq2seq translation 使用序列到序列的网络和注意机制进行翻译
MXNet的机器翻译实践《编码器-解码器(seq2seq)和注意力机制》