Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解
Posted Coding With you.....
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解相关的知识,希望对你有一定的参考价值。
src_shape=["你好","生日快乐"]
tgt_shape=["hello","happy birthday to you"]
#1.embeding 先构造数据集,再构造嵌入
batch_size=2 #以两条样本为例
src_shape=torch.Tensor([2,4]).to(torch.int32) #源:第一个样本长度为2 第二个样本长度为4 你好 生日快乐
tgt_shape=torch.Tensor([1,4]).to(torch.int32) #目标:第一个样本长度为1 第二个样本长度为4 hello happy birthday to you
#获取每一个词对应的索引:查询词汇表中
#数据处理:原始文本变为数字——即每个单词在词表中的位置;然后通过padding处理得到句子的二维索引表示
src_index = [[1,2], [3,1,5,6]]
tgt_index = [[7], [13,67,78,54]]
#因为输入模型中的长度要一致,所以需要进行处理 然后加padding apdding的索引默认是0
for i in len(src_shape):
assert len(src_index[i])==src_shape[i]
# 设置序列最大长度
max_num_word = 5
len_word=8 #单词表大小
#得到源句子和目标句子
#src_seq = [F.pad(torch.tensor(L),(0, max_num_word - len(L))) for i, L in enumerate(src_index)]
#print(src_seq) #[tensor([1, 2, 0, 0, 0]), tensor([3, 1, 5, 6, 0])]
#然后我们需要把这写batchsize个样本给合成一个二维的张量 因此首先变为2维-增加维度,然后在第0维cat 这样每一个样本就有一个向量变为一个二维张量
assert len(src_index)==batch_size
src_cat_result=torch.cat([torch.unsqueeze(F.pad(torch.tensor(L),(0, src_max_num_word - len(L))),0) for i, L in enumerate(src_index)])
#print(src_cat_result) #tensor([[1, 2, 0, 0, 0], [3, 1, 5, 6, 0]])
assert len(tgt_index)==batch_size
tgt_cat_result=torch.cat([torch.unsqueeze(F.pad(torch.tensor(L),(0, tgt_max_num_word - len(L))),0) for i, L in enumerate(tgt_index)])
#嵌入:将索引转化维向量表示
model_dim=8 #论文中是512
#嵌入使用pytprch中的 词向量嵌入,里面包含所有的单词
src_embedding_table=nn.Embedding(src_len_word+1,model_dim) #+1是运因为有paddiing,这个也是文本中所涉及的单词
tgt_embedding_table=nn.Embedding(tgt_len_word+1,model_dim)
#print(src_embedding_table.weight)
"""
第一行是padding的,其余8个表示文本中单词表为8,列表示嵌入维度,因此一行表示一个词的嵌入
tensor([[-1.2555, -0.7370, 1.4169, -0.0135, 0.1515, 0.1596, 1.2738, -1.3509], padding的
[-0.2124, 2.1522, 0.7495, 0.6297, 0.5902, 2.2523, 0.2023, 0.1858],
[ 0.2194, 0.0171, 0.9608, -1.0578, 1.0288, 0.2693, -1.3096, 0.2752],
[-0.7234, 0.0349, -0.5287, 1.0626, 0.2660, -2.7938, 1.7546, 0.6249],
[ 0.9379, 1.5381, 0.8142, 1.3533, -0.0833, -0.4148, -1.6053, -1.0058],
[-1.1222, 0.9535, 0.2647, 0.4451, -2.0066, 0.8431, 0.1354, -0.2863],
[ 0.4703, 0.8993, 0.0897, -1.6634, -0.4775, 1.8519, 0.2723, 1.5300],
[ 0.8387, -0.1637, -0.9713, -1.3232, -0.7950, -0.1871, -1.2212, 0.7863],
[-0.1340, 1.7016, -1.1327, 0.9578, 0.2550, -1.3456, -0.6001, 1.5825]],
requires_grad=True)
src_embedding=src_embedding_table(src_cat_result) #调用embedding类的forward方法
#print(src_embedding)
得到了每一句话的表示:2个句子*每一句5个词*每一个词8维
tensor([[[-1.0475, 0.2883, -1.9625, 0.9092, 0.3110, 0.4850, 0.0567,
0.8287],
[-0.4096, 0.4550, -0.2292, -0.1558, 0.1996, 0.2696, 0.4571,
1.7479],
[ 0.6344, 0.5071, 1.0565, 1.0850, -1.0610, -0.0961, -1.0493,
0.2836],
[ 0.6344, 0.5071, 1.0565, 1.0850, -1.0610, -0.0961, -1.0493,
0.2836],
[ 0.6344, 0.5071, 1.0565, 1.0850, -1.0610, -0.0961, -1.0493,
0.2836]],
[[-3.2312, 0.0042, 0.2851, -0.3745, 0.8082, 0.3877, -0.3762,
0.8496],
[-1.0475, 0.2883, -1.9625, 0.9092, 0.3110, 0.4850, 0.0567,
0.8287],
[-0.0264, -0.7717, -0.3651, -1.0353, -2.3864, 0.6747, -0.4873,
1.3909],
[-0.2352, 0.2003, 1.8471, -1.1417, -0.1079, 0.5878, -0.0224,
-0.4107],
[ 0.6344, 0.5071, 1.0565, 1.0850, -1.0610, -0.0961, -1.0493,
0.2836]]], grad_fn=<EmbeddingBackward>)
#位置嵌入:使用sin cos是因为泛化能力比较好,而且还具有对称性 句子数*每一句最大词数*嵌入维度 这里也可以用两个for循环来填充每一个句子的二维位置矩阵
#我们这种方法是矩阵相乘,应用了广播机制
max_position_len=5 #位置嵌入的最大长度 和句子的长度相等吧:因为句子中每个词都有位置
pos_mat=torch.arange(max_position_len).reshape((-1,1))
i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((-1,1))/model_dim)#2i应该是表示奇数列吧,因为是从0开始索引
pe_embedding_table=torch.zeros(max_position_len,model_dim)
div_term = torch.exp(torch.arange(0, model_dim, 2) * -(math.log(10000.0) / model_dim))
pe_embedding_table[:,0::2]=torch.sin(pos_mat*div_term) #[:,0::2]表示选择所有行以及偶数列 根据公式pos_mat*i_mat会报错
pe_embedding_table[:,1::2]=torch.cos(pos_mat*div_term) #[:,1::2]表示选择所有行以及奇数列
pe_embedding=nn.Embedding(max_position_len,model_dim)
pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False) #权重用自己的,不进行学习
#注意这里借鉴词嵌入的方法,流程一样,其中参数要传入位置而不是单词
src_pos_cat=torch.cat([torch.unsqueeze(torch.arange(max(src_shape)),0) for i in src_shape])
tgt_pos_cat=torch.cat([torch.unsqueeze(torch.arange(max(tgt_shape)),0) for i in tgt_shape])
src_pe_embedding=pe_embedding(src_pos_cat)
#print(src_pe_embedding)
"""
tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00]],
[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00]]],
grad_fn=<EmbeddingBackward>)
"""
tgt_pe_embedding=pe_embedding(tgt_pos_cat)
print(tgt_pe_embedding)
tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00]],
[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00]]],
grad_fn=<EmbeddingBackward>)
#encoding中的mask:padding的标记为-inf
以上是关于Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解的主要内容,如果未能解决你的问题,请参考以下文章
Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解
Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解
Transformer代码详解(-)-从数据处理到嵌入包含数据集可构造/数据嵌入和位置嵌入的详解
广告行业中那些趣事系列4:详解从配角到C位出道的Transformer