1. 导读
本节内容介绍如何使用RNN训练语言模型,计算一段文本存在的概率,并生成新的风格化文本序列。
2. 语言模型(Language model)
通过语言模型,我们可以计算某个特定句子出现的概率是多少,或者说该句子属于真实句子的概率是多少。正式点讲,一个序列模型模拟了任意特定单词序列的概率。
2.1 Language modelling with an RNN
- RNN中\\(x^{<1>}\\),\\(a^{<0>}\\)初始化为零向量。
- 通过softmax层进行预测,计算出第一个词可能是什么,其结果为\\(\\hat{y}^{<1>}\\)。这一步其实是通过softmax层计算词典里每一个词会是第一个词的概率,假设词典有n个词,那么\\(\\hat{y}^{<i>}\\)就是一个n维向量。
- RNN进入下个时间戳,将激活值\\(a^{<1>}\\)传递到第二个时间戳,这一步要计算句子中第二个词会是什么,此时的输入值为\\(y^{<1>}\\)(注意区分,不是\\(\\hat{y}^{<1>}\\)),将当前训练句子中真实的第一个词作为输入,即\\(x^{<2>}=y^{<1>}\\)。
- 接下来同样经过softmax层得到n维向量\\(\\hat{y}^{<2>}\\),每一维度所对应的数值,是已知第一词为\\(x^{<2>}\\)时第二个词是该单词(每个维度对应唯一一个单词)的概率。
- 以此类推,之后RNN中每个时间戳的输入都是上个时间戳的真实值\\(x^{<i>}=y^{<i-1>}\\),之后根据前\\((i-1)\\)个真实值\\(y^{<i>}\\)所包含的信息计算得到输出值\\(\\hat{y}^{<i>}\\),一个n维向量,n是词典的大小。
RNN训练语言模型时采用的损失函数是softmax的损失函数,即交叉熵损失函数,公式如下:
- 第一个公式中\\(y\\)是真实的数据分布(one hot 向量),\\(\\hat{y}\\)是预测的数据分布,i代表一个训练batch中的每个训练数据。
- 第二个公式将不同时刻的损失函数值累加,得到最终的损失函数。RNN在该公式的基础上进行反向传播。
2.2 计算句子概率
现在我们已经掌握了RNN训练语言模型的具体细节,那么当模型训练完成后,我们该如何应用此模型计算某一句子出现的概率?
- 假设测试句子的长度为\\(T_x\\),将该序列作为输入,重复上述操作但不进行反向传播。最后模型可以输出\\(T_x\\)个向量\\(\\hat{y}^{<i>}\\),从中提取真实单词的概率值,执行累乘操作便得到了该句子的概率值。
3. 序列生成
使用RNN训练好一个语言模型,我们不光可以用它计算某一个句子出现的概率值,也可以用它完成序列生成的任务,生成风格化文本。
- 比如用莎士比亚的作品集训练一个语言模型,之后用该模型生成一段莎士比亚风格的文本。
3.1 Sampling a sequence from a trained RNN
如前文所讲,一个序列模型模拟了任意特定单词序列的概率。我们所要做的就是,对这个概率分布进行采样,来生成一个新的单词序列。
因此,我们需要在训练好的模型基础上进行采样处理。如图所示:
- 首先我们对新句子序列中的第一个词进行采样,图中\\(a^{<0>}\\),\\(x^{<1>}\\)均初始化为零向量。在第一个时间戳中,将输出值通过softmax层得到的所有单词的概率,然后根据这个softmax的分布进行一次随机采样(使用np.random.choice()函数),得到新句子序列的第一个单词,而\\(\\hat{y}^{<1>}\\)就是该单词的one hot向量值。
- 然后进入第二个时间戳,此时需要\\(\\hat{y}^{<1>}\\)作为输入(与前文训练语言模型的输入不同,前文以真实值\\(y^{<1>}\\)作为输入),无论在上一个时间戳中采样得到什么单词,都将其作为下一个时间戳的输入。
- 接下来重复同样的操作,对第二个时间戳的softmax输出进行采样,得到新句子的序列的第二个单词,并将该单词的one hot向量作为第三个时间戳的输入。
- 以此类推,当生成EOS字符,或者时间戳增长到一定步数时,意味着语言模型已经生成了一个完整的句子。
在实现模型时,需要注意的有以下几点:
- 引入采样的是为了保证每个单词都有产生的可能,如果每次只是选取softmax对应的最大值,那么有些单词会永远生成不到,而且生成的句子会极为相似。
- 采样过程中可能会生成UNK字符,为了避免生成这种无意义字符,可以在每个时间戳中不断采样,直到得到一个非UNK字符才进入下一个时间戳。
- 模型的初始输入\\(x^{<1>}\\),可以不用初始化为零向量,而是用某一个单词的one hot向量作为赋值,这样可以生成与该单词主题相关的句子。
3.2 Sequence generation
这里展示序列生成的部分成果。
左图是使用以新闻语料训练生成的句子,右图是以莎士比亚作品集训练生成的句子。