使用Tensorflow搭建一个简单的Seq2Seq翻译模型

Posted 月来客栈

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用Tensorflow搭建一个简单的Seq2Seq翻译模型相关的知识,希望对你有一定的参考价值。


1.背景

首先,这篇博文整理自谷歌开源的神经机器翻译项目​​Neural Machine Translation (seq2seq) Tutorial​​。如果你直接克隆这个项目按照Tutorial中的说明操作即可,那么也就不用再往下看了。

而之所以写这篇博文的目的是,虽然Seq2Seq的原理并不太难,但是在用Tensorflow实现起来的时候却不那么容易。即使谷歌开源了源码,但是对于初学者来说面对复杂的工程结构文件,依旧是一头雾水(看来好几天,源码也没弄懂)。于是笔者就根据Tutorial中的说明以及各种摸索,终于搭建出了一个简单的翻译模型。下面就来大致介绍整个模型的搭建过程,数据的预处理,以及一些重要参数的说明等等。

由于笔者本身不搞自然语言这方面的内容,只是想学习这方面技术在Tensorflow中的使用,所以对于如何。

1.1 原理

使用Tensorflow搭建一个简单的Seq2Seq翻译模型_Seq2Seq

Seq2Seq模型的主要原理如图p0105所示,先是一个Sequence通过RNN网络结构编码(左边蓝色部分)后得到 “thought vector"(图中白色矩形框,后称为中间向量),也就是说此时的中间向量包含了输入向量的所有信息,可以将其视为一个“加密”的过程。紧接着就是将中间向量再次喂给另外一个RNN网络对其进行解码(右边棕色部分),然后得到解码后的输出,可将其视为一个“解密”的过程。可能这个图太抽象了,我们再来进一步细化这个图:

使用Tensorflow搭建一个简单的Seq2Seq翻译模型_tensorflow_02

如图p0106所示,输入部分以每个单词作为RNN对应每个时刻的输入,而输出部分呢则以RNN上一时刻的输出作为下一时刻的输入,直到输出为终止符"“为止。有没有发现,这同我们之间介绍的​​用LSTM来生成唐诗​​的原理一模一样? 但是呢请注意这个问题:在训练的时候我们并不能保证解码部分每个时刻的输出就是正确的。换句话说就是,假设第一个时刻的输出为“我”,然后接着将”我“喂给下一时刻,但此时预测的结果为”你好“,然后再把”你好“喂给下一时刻预测出”明天“,最后将”明天“喂给下一此时预测出”“结束。也就是最终预测的序列为 y ^ = \\haty= y^=" 我 你好 明天 " ,虽然这样也能同正确标签 y = y= y="我 是 一个 学生 "做交叉熵然后训练网络,但这就导致训练出来的网络可能效果不好。而再翻译模型中,普遍的做法就是在训练时,解码部分每个时刻的输入就是正确标签,然后再将预测结果同正确标签做交叉熵;而在预测(inference)时再采取上一时刻的输出作为下一时刻的输入(此时也没有所谓的正确标签)这一策略,如图p0107所示:

使用Tensorflow搭建一个简单的Seq2Seq翻译模型_Seq2Seq_03

也就是说,在训练时采取的策略如图p0107所示,而预测时的策略如图p0106所示。同时,由于输入为字符序列,所以首先要进行embedding处理;其次,一般都采用多层RNN来构造网络。由此我们便得到了如图p0108所示的网络结构:

使用Tensorflow搭建一个简单的Seq2Seq翻译模型_数据预处理_04

1.2 前期准备

为了方便后面在介绍Tensorflow时一些函数的使用方法(参数的左右),在这里首先来大致介绍以下几个重要又不容易理解的变量。

从图p0108可知,整个网络模型至少需要三个​​placeholder​​​,即​​encoder_inputs,decoder_inputs,decoder_outputs​​,其分别为source input wors,target input words,target output words三个部分的输入或输出。同时,由于每个sequence的长短都是不一样的,因此在NMT这个模型中,这三个地方的变量的shape都不是固定的。有人可能会说了,将所有的句子都Padding成一个长度不久行了吗? 虽然来说理论上可以这样,但是由于sequences之间长短相差太大(至少是在NMT中),如果所有sequence都padding成一个长度,效果肯定不好,所以NMT采取的做法是:只在同一个batch中保持所有sequence的长度一样(不够的以最长的为标准再padding),也就是说同一batch保持一致,不同batch之间可以不同。

假设现在source input words中有一个batch,batch中有5个sequence,其长度分别为5,7,3,8,6,则:

​encoder_inputs.shape=[8,5]​​​; 指定了​​time_major=True​​​(不明白​​time_major​​​​戳此处见第3点​​)

​source_lengths=[5,7,3,8,6]​​; 记录每个sequence的长度

​max_source_length=8​​; 记录最长sequence的长度

1.3 数据预处理

以下以3个样本为例来主要介绍一下数据预处理部分。

汉:[[你 是 谁 ?], [你 从 哪里 来?],[你 要 到 哪里 去 ?]]

英: [[who are you ?],[where are you from ?],[where are you going ?]]


  • 首先是根据汉和英分别建立各自的单词表​​src_vocab_table,tgt_vocab_table​​,且​​UNK,SOS,EOS,PAD​​在词表中的顺序分别为[0,1,2,3]。
  • 将样本转换为各自词表中的索引(同时Padding):
    ​source_inputs=[[4,5,6,7,3,3],[4,8,9,10,7,3],[4,11,12,9,13,7]] source_lengths=[4,5,6] max_source_length = 6 target_inputs=[[1,4,5,6,7,3],[1,8,5,6,9,7],[1,8,5,6,10,7]] target_lengths=[5,6,6] max_target_length=6 target_outputs=[[4,5,6,7,2,3],[8,5,6,9,7,2],[8,5,6,10,7,2]]​注意,对于​​target_inputs,target_outputs​​来说,一定是先加上起始符和终止符再padding.

2. 编码与解码

2.1 编码encoder

encoder编码部分和写​​LSTM​​​这种网络结构几乎一样,都是通过​​dynamic_rnn​​​这个函数来完成的。目前发现唯一的区别在于此处多了一个参数​​sorce_lengths​​​,其原因是因为每个sequence的长度不一样(尽管每个baatch里padding成一样了),所以要告诉​​dynamic_rnn​​展开的时间维度。

def _build_encoder(self):
def get_encoder_cell(rnn_size):
lstm_cell = tf.nn.rnn_cell.LSTMCell(rnn_size)
return lstm_cell

encoder_cell = tf.nn.rnn_cell.MultiRNNCell([get_encoder_cell(self.encoder_rnn_size) for _ in range(self.encoder_rnn_layer)])
self.encoder_outputs, self.encoder_final_state =
tf.nn.dynamic_rnn(cell=encoder_cell,
nputs=self.encoder_emb_inp,
sequence_length=self.source_lengths,
time_major=True,
dtype=tf.float32)


Note that sentences have different lengths to avoid wasting computation, we tell dynamic_rnn the exact source sentence lengths through source_sequence_length.


2.2 解码decoder

在上面的1.1节中我们说到,NMT的解码部分在实现的时候分为训练和推断(预测)两个部分,因此对于这两个部分也要分开来写。

在训练时通过​​TrainingHelper​​​这个函数来构造一个辅助对象,达到给每个时刻输入正确label的目的,然后通过​​BasicDecoder​​​和​​dynamic_decoder​​​进行解码;而在预测时则通过​​GreedyEmbeddingHelper​​​这个辅助对象来完成将上以时刻的输出作为下一时刻的输入这一步骤,然后同样通过​​BasicDecoder​​​和​​dynamic_decoder​​进行解码。

由于这部分代码贴出来排版看起来很乱影响阅读体验,所以就不贴了,直接参考文末贴出的代码即可。当然,接下来的就是构造损失函数等其它步骤了,参照代码中的注释即可。

3. 总结

总体来说对于使用Tensorflow来完成这个示例的难点在于一些参数的理解上,也就是1.1节中提到的几个参数。只要把这几个参数的含义弄明白了,照葫芦画瓢相对来说还是不那么困难。

​源码戳此处​

使用Tensorflow搭建一个简单的Seq2Seq翻译模型_Seq2Seq_05



以上是关于使用Tensorflow搭建一个简单的Seq2Seq翻译模型的主要内容,如果未能解决你的问题,请参考以下文章

python tensorflow 2.0 不使用 Keras 搭建简单的 LSTM 网络

简单的验证码识别之---------tensorflow环境搭建

记一次使用Tensorflow搭建神经网络模型经历

学会用tensorflow搭建简单的神经网络 2

神经网络一(用tensorflow搭建简单的神经网络并可视化)

tensorflow学习之搭建最简单的神经网络