在 PyTorch 中准备序列到序列网络的解码器
Posted
技术标签:
【中文标题】在 PyTorch 中准备序列到序列网络的解码器【英文标题】:Prepare Decoder of a Sequence to Sequence Network in PyTorch 【发布时间】:2019-02-25 12:11:59 【问题描述】:我在 Pytorch 中使用序列到序列模型。序列到序列模型由编码器和解码器组成。
编码器转换(batch_size X input_features X num_of_one_hot_encoded_classes) -> (batch_size X input_features X hidden_size)
解码器将获取这个输入序列并将其转换为(batch_size X output_features X num_of_one_hot_encoded_classes)
举个例子-
所以在上面的例子中,我需要将 22 个输入特征转换为 10 个输出特征。在 Keras 中,可以使用 RepeatVector(10) 来完成。
一个例子 -
model.add(LSTM(256, input_shape=(22, 98)))
model.add(RepeatVector(10))
model.add(Dropout(0.3))
model.add(LSTM(256, return_sequences=True))
虽然,我不确定这是否是将输入序列转换为输出序列的正确方法。
所以,我的问题是——
将输入序列转换为的标准方法是什么 输出的。例如。从 (batch_size, 22, 98) -> (batch_size, 10, 98)?或者我应该如何准备解码器?编码器代码 sn-p(用 Pytorch 编写)-
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
num_layers=1, batch_first=True)
def forward(self, input):
output, hidden = self.lstm(input)
return output, hidden
【问题讨论】:
在您的示例中,input_features
对应于“序列长度”维度。为什么要事先指定输出序列长度,而不是让解码器自然地预测“序列结束”标记?
【参考方案1】:
好吧,你必须选择,第一个是重复编码器的最后一个状态 10 次,并将其作为输入给解码器,如下所示:
import torch
input = torch.randn(64, 22, 98)
encoder = torch.nn.LSTM(98, 256, batch_first=True)
encoded, _ = encoder(input)
decoder_input = encoded[:, -1:].repeat(1, 10, 1)
decoder = torch.nn.LSTM(256, 98, batch_first=True)
decoded, _ = decoder(decoder_input)
print(decoded.shape) #torch.Size([64, 10, 98])
另一种选择是使用注意力机制,如下所示:
#assuming we have obtained the encoded sequence and declared the decoder as before
attention_calculator = torch.nn.Conv1d(256+98, 1, kernel_size=1)
hidden = (torch.zeros(1, 64, 98), torch.zeros(1, 64, 98))
outputs = []
for i in range(10):
attention_input = torch.cat([hidden[0][0][:, None, :].expand(-1, 22, -1), encoded], dim=2).permute(0, 2, 1)
attention_value = torch.nn.functional.softmax(attention_calculator(attention_input).squeeze(), dim=1)
decoder_input = (attention_value[:, :, None] * encoded).sum(dim=1, keepdim=True)
output, hidden = decoder(decoder_input, hidden)
outputs.append(output)
outputs = torch.cat(outputs, dim=1)
【讨论】:
以上是关于在 PyTorch 中准备序列到序列网络的解码器的主要内容,如果未能解决你的问题,请参考以下文章
干货对抗自编码器PyTorch手把手实战系列——PyTorch实现对抗自编码器
PyTorch 1.0 中文官方教程:序列模型和LSTM网络