NLP经典案例Transformer 构建语言模型

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了NLP经典案例Transformer 构建语言模型相关的知识,希望对你有一定的参考价值。

引言

什么是语言模型:

以一个符合语言规律的序列为输入,模型将利用序列间关系等特征,输出一个在所有词汇上的概率分布.这样的模型称为语言模型.

# 语言模型的训练语料一般来自于文章,对应的源文本和目标文本形如:
src1 = "I can do" tgt1 = "can do it"
src2 = "can do it", tgt2 = "do it <eos>"

语言模型能解决哪些问题:

  1. 根据语言模型的定义,可以在它的基础上完成机器翻译,文本生成等任务,因为我们通过最后输出的概率分布来预测下一个词汇是什么.

  2. 语言模型可以判断输入的序列是否为一句完整的话,因为我们可以根据输出的概率分布查看最大概率是否落在句子结束符上,来判断完整性.

  3. 语言模型本身的训练目标是预测下一个词,因为它的特征提取部分会抽象很多语言序列之间的关系,这些关系可能同样对其他语言类任务有效果.因此可以作为预训练模型进行迁移学习.

整个案例的实现可分为以下五个步骤

  • 第一步: 导入必备的工具包
  • 第二步: 导入wikiText-2数据集并作基本处理
  • 第三步: 构建用于模型输入的批次化数据
  • 第四步: 构建训练和评估函数
  • 第五步: 进行训练和评估(包括验证以及测试)

1. 导入必备的工具包

pytorch版本必须使用1.3.1, python版本使用3.6.x

pip install torch==1.3.1
# 数学计算工具包math
import math

# torch以及torch.nn, torch.nn.functional
import torch
import torch.nn as nn
import torch.nn.functional as F

# torch中经典文本数据集有关的工具包
# 具体详情参考下方torchtext介绍
import torchtext

# torchtext中的数据处理工具, get_tokenizer用于英文分词
from torchtext.data.utils import get_tokenizer

# 已经构建完成的TransformerModel
from pyitcast.transformer import TransformerModel

torchtext介绍:

  • 它是torch工具中处理NLP问题的常用数据处理包.

torchtext的重要功能:

  • 对文本数据进行处理, 比如文本语料加载, 文本迭代器构建等.
  • 包含很多经典文本语料的预加载方法. 其中包括的语料有:用于情感分析的SST和IMDB, 用于问题分类的TREC, 用于及其翻译的 WMT14, IWSLT,以及用于语言模型任务wikiText-2, WikiText103, PennTreebank.

我们这里使用wikiText-2来训练语言模型, 下面有关该数据集的相关详情:

wikiText-2数据集的体量中等, 训练集共有600篇短文, 共208万左右的词汇, 33278个不重复词汇, OoV(有多少正常英文词汇不在该数据集中的占比)为2.6%,数据集中的短文都是维基百科中对一些概念的介绍和描述.

2. 导入wikiText-2数据集并作基本处理

# 创建语料域, 语料域是存放语料的数据结构, 
# 它的四个参数代表给存放语料(或称作文本)施加的作用. 
# 分别为 tokenize,使用get_tokenizer("basic_english")获得一个分割器对象,
# 分割方式按照文本为基础英文进行分割. 
# init_token为给文本施加的起始符 <sos>给文本施加的终止符<eos>, 
# 最后一个lower为True, 存放的文本字母全部小写.
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)

# 最终获得一个Field对象.
# <torchtext.data.field.Field object at 0x7fc42a02e7f0>

# 然后使用torchtext的数据集方法导入WikiText2数据, 
# 并切分为对应训练文本, 验证文本,测试文本, 并对这些文本施加刚刚创建的语料域.
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)

# 我们可以通过examples[0].text取出文本对象进行查看.
# >>> test_txt.examples[0].text[:10]
# ['<eos>', '=', 'robert', '<unk>', '=', '<eos>', '<eos>', 'robert', '<unk>', 'is']

# 将训练集文本数据构建一个vocab对象, 
# 这样可以使用vocab对象的stoi方法统计文本共包含的不重复词汇总数.
TEXT.build_vocab(train_txt)

# 然后选择设备cuda或者cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

该案例的所有代码都将实现在一个transformer_lm.py文件中.

3. 构建用于模型输入的批次化数据

批次化过程的第一个函数batchify代码分析:

def batchify(data, bsz):
    """batchify函数用于将文本数据映射成连续数字, 并转换成指定的样式, 指定的样式可参考下图.
       它有两个输入参数, data就是我们之前得到的文本数据(train_txt, val_txt, test_txt),
       bsz是就是batch_size, 每次模型更新参数的数据量"""
    # 使用TEXT的numericalize方法将单词映射成对应的连续数字.
    data = TEXT.numericalize([data.examples[0].text])
    # >>> data
    # tensor([[   3],
    #    [  12],
    #    [3852],
    #    ...,
    #    [   6],
    #    [   3],
    #    [   3]])

    # 接着用数据词汇总数除以bsz,
    # 取整数得到一个nbatch代表需要多少次batch后能够遍历完所有数据
    nbatch = data.size(0) // bsz

    # 之后使用narrow方法对不规整的剩余数据进行删除,
    # 第一个参数是代表横轴删除还是纵轴删除, 0为横轴,1为纵轴
    # 第二个和第三个参数代表保留开始轴到结束轴的数值.类似于切片
    # 可参考下方演示示例进行更深理解.
    data = data.narrow(0, 0, nbatch * bsz)
    # >>> data
    # tensor([[   3],
    #    [  12],
    #    [3852],
    #    ...,
    #    [  78],
    #    [ 299],
    #    [  36]])
    # 后面不能形成bsz个的一组数据被删除

    # 接着我们使用view方法对data进行矩阵变换, 使其成为如下样式:
    # tensor([[    3,    25,  1849,  ...,     5,    65,    30],
    #    [   12,    66,    13,  ...,    35,  2438,  4064],
    #    [ 3852, 13667,  2962,  ...,   902,    33,    20],
    #    ...,
    #    [  154,     7,    10,  ...,     5,  1076,    78],
    #    [   25,     4,  4135,  ...,     4,    56,   299],
    #    [    6,    57,   385,  ...,  3168,   737,    36]])
    # 因为会做转置操作, 因此这个矩阵的形状是[None, bsz],
    # 如果输入是训练数据的话,形状为[104335, 20], 可以通过打印data.shape获得.
    # 也就是data的列数是等于bsz的值的.
    data = data.view(bsz, -1).t().contiguous()
    # 最后将数据分配在指定的设备上.
    return data.to(device)

batchify的样式转化图:


大写字母A,B,C … 代表句子中的每个单词.

torch.narrow演示:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> x.narrow(0, 0, 2)
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
>>> x.narrow(1, 1, 2)
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9]])

接下来我们将使用batchify来处理训练数据,验证数据以及测试数据.

# 训练数据的batch size
batch_size = 20

# 验证和测试数据(统称为评估数据)的batch size
eval_batch_size = 10

# 获得train_data, val_data, test_data
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)

上面的分割批次并没有进行源数据与目标数据的处理, 接下来我们将根据语言模型训练的语料规定来构建源数据与目标数据.

语言模型训练的语料规定:

  • 如果源数据为句子ABCD, ABCD代表句子中的词汇或符号, 则它的目标数据为BCDE, BCDE分别代表ABCD的下一个词汇.


如图所示,我们这里的句子序列是竖着的, 而且我们发现如果用一个批次处理完所有数据, 以训练数据为例, 每个句子长度高达104335, 这明显是不科学的, 因此我们在这里要限定每个批次中的句子长度允许的最大值bptt.

批次化过程的第二个函数get_batch代码分析:

# 令子长度允许的最大值bptt为35
bptt = 35

def get_batch(source, i):
    """用于获得每个批次合理大小的源数据和目标数据.
       参数source是通过batchify得到的train_data/val_data/test_data.
       i是具体的批次次数.
    """

    # 首先我们确定句子长度, 它将是在bptt和len(source) - 1 - i中最小值
    # 实质上, 前面的批次中都会是bptt的值, 只不过最后一个批次中, 句子长度
    # 可能不够bptt的35个, 因此会变为len(source) - 1 - i的值.
    seq_len = min(bptt, len(source) - 1 - i)

    # 语言模型训练的源数据的第i批数据将是batchify的结果的切片[i:i+seq_len]
    data = source[i:i+seq_len]

    # 根据语言模型训练的语料规定, 它的目标数据是源数据向后移动一位
    # 因为最后目标数据的切片会越界, 因此使用view(-1)来保证形状正常.
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

输入实例:

# 以测试集数据为例
source = test_data
i = 1

输出效果:

data = tensor([[   12,  1053,   355,   134,    37,     7,     4,     0,   835,  9834],
        [  635,     8,     5,     5,   421,     4,    88,     8,   573,  2511],
        [    0,    58,     8,     8,     6,   692,   544,     0,   212,     5],
        [   12,     0,   105,    26,     3,     5,     6,     0,     4,    56],
        [    3, 16074, 21254,   320,     3,   262,    16,     6,  1087,    89],
        [    3,   751,  3866,    10,    12,    31,   246,   238,    79,    49],
        [  635,   943,    78,    36,    12,   475,    66,    10,     4,   924],
        [    0,  2358,    52,     4,    12,     4,     5,     0, 19831,    21],
        [   26,    38,    54,    40,  1589,  3729,  1014,     5,     8,     4],
        [   33, 17597,    33,  1661,    15,     7,     5,     0,     4,   170],
        [  335,   268,   117,     0,     0,     4,  3144,  1557,     0,   160],
        [  106,     4,  4706,  2245,    12,  1074,    13,  2105,     5,    29],
        [    5, 16074,    10,  1087,    12,   137,   251, 13238,     8,     4],
        [  394,   746,     4,     9,    12,  6032,     4,  2190,   303, 12651],
        [    8,   616,  2107,     4,     3,     4,   425,     0,    10,   510],
        [ 1339,   112,    23,   335,     3, 22251,  1162,     9,    11,     9],
        [ 1212,   468,     6,   820,     9,     7,  1231,  4202,  2866,   382],
        [    6,    24,   104,     6,     4,     4,     7,    10,     9,   588],
        [   31,   190,     0,     0,   230,   267,     4,   273,   278,     6],
        [   34,    25,    47,    26,  1864,     6,   694,     0,  2112,     3],
        [   11,     6,    52,   798,     8,    69,    20,    31,    63,     9],
        [ 1800,    25,  2141,  2442,   117,    31,   196,  7290,     4,   298],
        [   15,   171,    15,    17,  1712,    13,   217,    59,   736,     5],
        [ 4210,   191,   142,    14,  5251,   939,    59,    38, 10055, 25132],
        [  302,    23, 11718,    11,    11,   599,   382,   317,     8,    13],
        [   16,  1564,     9,  4808,     6,     0,     6,     6,     4,     4],
        [    4,     7,    39,     7,  3934,     5,     9,     3,  8047,   557],
        [  394,     0, 10715,  3580,  8682,    31,   242,     0, 10055,   170],
        [   96,     6,   144,  3403,     4,    13,  1014,    14,     6,  2395],
        [    4,     3, 13729,    14,    40,     0,     5,    18,   676,  3267],
        [ 1031,     3,     0,   628,  1589,    22, 10916, 10969,     5, 22548],
        [    9,    12,     6,    84,    15,    49,  3144,     7,   102,    15],
        [  916,    12,     4,   203,     0,   273,   303,   333,  4318,     0],
        [    6,    12,     0,  4842,     5,    17,     4,    47,  4138,  2072],
        [   38,   237,     5,    50,    35,    27, 18530,   244,    20,     6]])

target =  tensor([  635,     8,     5,     5,   421,     4,    88,     8,   573,  2511,
            0,    58,     8,     8,     6,   692,   544,     0,   212,     5,
           12,     0,   105,    26,     3,     5,     6,     0,     4,    56,
            3, 16074, 21254,   320,     3,   262,    16,     6,  1087,    89,
            3,   751,  3866,    10,    12,    31,   246,   238,    79,    49,
          635,   943,    78,    36,    12,   475,    66,    10,     4,   924,
            0,  2358,    52,     4,    12,     4,     5,     0, 19831,    21,
           26,    38,    54,    40,  1589,  3729,  1014,     5,     8,     4,
           33, 17597,    33,  1661,    15,     7,     5,     0,     4,   170,
          335,   268,   117,     0,     0,     4,  3144,  1557,     0,   160,
          106,     4,  4706,  2245,    12,  1074,    13,  2105,     5,    29,
            5, 16074,    10,  1087,    12,   137,   251, 13238,     8,     4,
          394,   746,

以上是关于NLP经典案例Transformer 构建语言模型的主要内容,如果未能解决你的问题,请参考以下文章

NLP预训练语言模型(三):逐步解析Transformer结构

图解NLP模型发展:从RNN到Transformer

自然语言处理(NLP)基于Transformer的英文自动文摘

最新综述!NLP中的Transformer预训练模型

在线培训预告深度学习在自然语言处理领域的应用

RWKV – transformer 与 RNN 的强强联合