PyTorch笔记 - Seq2Seq + Attention 源码
Posted SpikeKing
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch笔记 - Seq2Seq + Attention 源码相关的知识,希望对你有一定的参考价值。
Encoder:编码器,将序列建模为上下文相关的表征,输入:
- embedding_dim
- hidden_size
- embedding_table,转换,vocab id到embedding向量
Seq2SeqAttentionMechanism:Attention机制,输入为t时刻的解码器状态和encoder的全部states;内部只做点乘、softmax和加权求和运算,没有需要学习的参数
Decoder:解码器,与Encoder使用处理整段序列的nn.LSTM(lstm_layer)不同,Decoder使用nn.LSTMCell逐步解码,每一步都结合Attention得到的context向量进行预测
import torch
import torch.nn as nn
import torch.nn.functional as F
"""
以离散符号的分类任务为例,实现基于注意力机制的seq2seq模型
"""
class Seq2SeqEncoder(nn.Module):
    """LSTM-based encoder that turns source token ids into contextual states.

    Other encoder families (CNN, Transformer encoder, ...) could be used
    instead; the decoder only consumes the per-position output states.

    Args:
        embedding_dim: Size of each source token embedding.
        hidden_size: Size of the LSTM hidden state.
        source_vocab_size: Number of entries in the source embedding table.
    """

    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super().__init__()
        self.lstm_layer = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        # Lookup table mapping a source vocab id to its embedding vector.
        self.embedding_table = nn.Embedding(source_vocab_size, embedding_dim)

    def forward(self, input_ids):
        """Encode a batch of source sequences.

        Args:
            input_ids: Integer tensor of source token ids, [bs, source_length].

        Returns:
            A tuple ``(output_states, final_h)``: the LSTM output for every
            source position, [bs, source_length, hidden_size] (the LSTM is
            batch-first), and the final hidden state, [1, bs, hidden_size].
        """
        embedded = self.embedding_table(input_ids)  # [bs, source_length, embedding_dim]
        # The final cell state is not needed by callers, so it is discarded.
        output_states, (final_h, _final_c) = self.lstm_layer(embedded)
        return output_states, final_h
class Seq2SeqAttentionMechanism(nn.Module):
    """Dot-product attention with no learnable parameters.

    Scores each encoder state by its dot product with the current decoder
    state, normalizes the scores with a softmax over the source positions and
    returns the resulting weighted sum of encoder states.
    """

    def __init__(self):
        super().__init__()

    def forward(self, decoder_state_t, encoder_states):
        """Attend over all encoder states from one decoder step.

        Args:
            decoder_state_t: Decoder hidden state at step t, [bs, hidden_size].
            encoder_states: All encoder states, [bs, source_length, hidden_size].

        Returns:
            A tuple ``(attn_prob, context)``: attention weights,
            [bs, source_length], and the context vector, [bs, hidden_size].

        Raises:
            ValueError: If ``encoder_states`` is not a 3-D tensor.
        """
        if encoder_states.dim() != 3:
            raise ValueError(
                "encoder_states must be 3-D [bs, source_length, hidden_size], "
                f"got shape {tuple(encoder_states.shape)}"
            )

        # [bs, 1, hidden_size] broadcasts against the encoder states, so the
        # decoder state does not have to be physically tiled along the
        # source axis.
        query = decoder_state_t.unsqueeze(1)

        # Dot-product score for every source position.
        score = torch.sum(query * encoder_states, dim=-1)  # [bs, source_length]
        attn_prob = F.softmax(score, dim=-1)

        # Attention-weighted sum of the encoder states.
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, dim=1)
        return attn_prob, context
class Seq2SeqDecoder(nn.Module):
    """Step-wise LSTMCell decoder with dot-product attention over the encoder.

    At every step the decoder hidden state attends over all encoder states;
    the context vector is concatenated with the hidden state and projected to
    ``num_classes`` logits. The LSTM state starts from the LSTMCell default
    (no initial state is passed in).

    Args:
        embedding_dim: Size of each target token embedding.
        hidden_size: Size of the LSTMCell hidden state; must equal the encoder
            state size because attention uses a plain dot product.
        num_classes: Output size of the classification layer. Predicted ids
            are fed back through the target embedding table at inference, so
            it should not exceed ``target_vocab_size``.
        target_vocab_size: Number of entries in the target embedding table.
        start_id: Token id that starts decoding at inference time.
        end_id: Token id that terminates decoding at inference time.
    """

    def __init__(self, embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id):
        super().__init__()
        self.lstm_cell = torch.nn.LSTMCell(embedding_dim, hidden_size)
        # The projection sees [context ; hidden state], hence hidden_size * 2.
        self.proj_layer = nn.Linear(hidden_size * 2, num_classes)
        self.attention_mechanism = Seq2SeqAttentionMechanism()
        self.num_classes = num_classes
        self.embedding_table = torch.nn.Embedding(target_vocab_size, embedding_dim)
        self.start_id = start_id
        self.end_id = end_id

    def _step(self, decoder_input_t, state, encoder_states):
        """Run one decoding step: LSTMCell update, attention, projection.

        ``state`` is the previous ``(h, c)`` pair, or None on the first step,
        in which case the LSTMCell uses its default initial state.
        """
        h_t, c_t = self.lstm_cell(decoder_input_t, state)
        attn_prob, context = self.attention_mechanism(h_t, encoder_states)
        logits_t = self.proj_layer(torch.cat((context, h_t), -1))
        return (h_t, c_t), attn_prob, logits_t

    def forward(self, shifted_target_ids, encoder_states):
        """Teacher-forced decoding, used during training.

        Args:
            shifted_target_ids: Target ids shifted right by one position
                (starting with the start token), [bs, target_length].
            encoder_states: Encoder outputs, [bs, source_length, hidden_size].

        Returns:
            A tuple ``(probs, logits)``: attention weights per target step,
            [bs, target_length, source_length], and classification logits,
            [bs, target_length, num_classes].
        """
        shifted_target = self.embedding_table(shifted_target_ids)  # [bs, target_length, embedding_dim]
        target_length = shifted_target.shape[1]

        state = None
        step_probs = []
        step_logits = []
        for t in range(target_length):
            state, attn_prob, logits_t = self._step(shifted_target[:, t, :], state, encoder_states)
            step_probs.append(attn_prob)
            step_logits.append(logits_t)

        if not step_logits:
            # Empty target: keep the documented shapes with a zero-length axis.
            bs, source_length = encoder_states.shape[:2]
            return (
                encoder_states.new_zeros(bs, 0, source_length),
                encoder_states.new_zeros(bs, 0, self.num_classes),
            )

        # Stacking the per-step results (instead of writing into preallocated
        # torch.zeros buffers) keeps outputs on the model's device and dtype.
        probs = torch.stack(step_probs, dim=1)
        logits = torch.stack(step_logits, dim=1)
        return probs, logits

    def inference(self, encoder_states, max_steps=100):
        """Greedy decoding, used at inference time.

        Decoding starts from ``start_id`` for every sequence in the batch and
        stops once every sequence has produced ``end_id`` at least once, or
        after ``max_steps`` steps. Sequences that finish early keep producing
        tokens until the whole batch is done; truncate each column at its
        first ``end_id`` if needed.

        Args:
            encoder_states: Encoder outputs, [bs, source_length, hidden_size].
            max_steps: Upper bound on the number of decoding steps, so that a
                model that never predicts ``end_id`` cannot loop forever.

        Returns:
            Predicted ids, [num_steps, bs] (steps are stacked on dim 0).

        Raises:
            ValueError: If ``max_steps`` is smaller than 1.
        """
        if max_steps < 1:
            raise ValueError("max_steps must be at least 1")

        bs = encoder_states.shape[0]
        device = encoder_states.device

        # nn.Embedding needs an index tensor, not a Python int; a scalar
        # start id is expanded to one entry per batch element.
        start = torch.as_tensor(self.start_id, dtype=torch.long, device=device)
        target_id = start.expand(bs) if start.dim() == 0 else start

        finished = torch.zeros(bs, dtype=torch.bool, device=device)
        state = None
        result = []
        for _ in range(max_steps):
            decoder_input_t = self.embedding_table(target_id)
            state, _attn_prob, logits = self._step(decoder_input_t, state, encoder_states)
            # The prediction of this step is the input of the next step.
            target_id = torch.argmax(logits, -1)
            result.append(target_id)
            finished = finished | (target_id == self.end_id)
            if torch.all(finished):
                print("stop decoding!")
                break

        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids
class Model(nn.Module):
    """Encoder-decoder model with attention for discrete-symbol sequences.

    Args:
        embedding_dim: Embedding size shared by the encoder and the decoder.
        hidden_size: Hidden state size shared by the encoder and the decoder.
        num_classes: Output size of the decoder's classification layer.
        source_vocab_size: Size of the source embedding table.
        target_vocab_size: Size of the target embedding table.
        start_id: Token id that starts decoding at inference time.
        end_id: Token id that terminates decoding at inference time.
    """

    def __init__(self, embedding_dim, hidden_size, num_classes,
                 source_vocab_size, target_vocab_size, start_id, end_id):
        super().__init__()
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        self.decoder = Seq2SeqDecoder(
            embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id
        )

    def forward(self, input_sequence_ids, shifted_target_ids):
        """Training pass with teacher forcing.

        Args:
            input_sequence_ids: Ids of the source sentence.
            shifted_target_ids: Ids of the target sentence, shifted right.

        Returns:
            The ``(probs, logits)`` tuple produced by the decoder.
        """
        # The encoder's final hidden state is not consumed by the decoder.
        encoder_states, _final_h = self.encoder(input_sequence_ids)
        probs, logits = self.decoder(shifted_target_ids, encoder_states)
        return probs, logits

    def infer(self, input_sequence_ids=None):
        """Encode the source ids and decode with the decoder's inference loop.

        Args:
            input_sequence_ids: Ids of the source sentence. When omitted the
                method does nothing and returns None, which matches the
                original no-argument placeholder.

        Returns:
            Whatever ``self.decoder.inference`` returns for the encoded
            source, or None when no input is given.
        """
        if input_sequence_ids is None:
            return None
        # No gradients are needed for decoding.
        with torch.no_grad():
            encoder_states, _final_h = self.encoder(input_sequence_ids)
            return self.decoder.inference(encoder_states)
if __name__ == '__main__':
    # Single forward-pass simulation. Real training would additionally need a
    # dataloader and mini-batch training loop.
    source_length = 3
    target_length = 4
    embedding_dim = 8
    hidden_size = 16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    source_vocab_size = 100
    # The classification layer predicts target token ids, so the target
    # vocabulary and the number of classes must agree.
    target_vocab_size = num_classes

    # Source sequence ids.
    input_sequence_ids = torch.randint(source_vocab_size, size=(bs, source_length))

    # Target ids with end_id appended as the last position.
    target_ids = torch.randint(target_vocab_size, size=(bs, target_length))
    end_column = torch.full((bs, 1), end_id, dtype=torch.long)
    target_ids = torch.cat((target_ids, end_column), dim=1)

    # Decoder input for teacher forcing: the targets shifted right by one,
    # i.e. start_id followed by every target token except the last.
    start_column = torch.full((bs, 1), start_id, dtype=torch.long)
    shifted_target_ids = torch.cat((start_column, target_ids[:, :-1]), dim=1)

    model = Model(embedding_dim, hidden_size, num_classes,
                  source_vocab_size, target_vocab_size, start_id, end_id)
    probs, logits = model(input_sequence_ids, shifted_target_ids)
    print(probs.shape)
    print(logits.shape)
以上是关于PyTorch笔记 - Seq2Seq + Attention 源码的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch笔记 - Seq2Seq + Attention 源码
PyTorch笔记 - Seq2Seq + Attention 源码