批处理如何在 pytorch 的 seq2seq 模型中工作?
Posted
技术标签:
【中文标题】批处理如何在 pytorch 的 seq2seq 模型中工作?【英文标题】:How does batching work in a seq2seq model in pytorch? 【发布时间】:2018-08-23 07:38:46 【问题描述】:我正在尝试在 Pytorch 中实现一个 seq2seq 模型,但批处理有一些问题。 例如我有一批数据,其维度是
[batch_size, sequence_lengths, encoding_dimension]
批次中每个示例的序列长度不同。
现在,我设法通过将批次中的每个元素填充到最长序列的长度来完成编码部分。
这样,如果我给我的网络一个与所说形状相同的批次作为输入,我会得到以下输出:
输出,形状为
[batch_size, sequence_lengths, hidden_layer_dimension]
隐藏状态,形如
[batch_size, hidden_layer_dimension]
细胞状态,形状为
[batch_size, hidden_layer_dimension]
现在,从输出,我为每个序列取最后一个相关元素,即沿sequence_lengths
维度的元素,对应于序列的最后一个非填充元素。因此我得到的最终输出是[batch_size, hidden_layer_dimension]
。
但现在我遇到了从这个向量解码它的问题。如何处理同一批次中不同长度序列的解码?我试图用谷歌搜索它并找到this,但他们似乎没有解决这个问题。我想为整个批次逐个元素地做,但是我有传递初始隐藏状态的问题,因为来自编码器的那些将是形状[batch_size, hidden_layer_dimension]
,而来自解码器的那些将是形状[1, hidden_layer_dimension]
.
我错过了什么吗?感谢您的帮助!
【问题讨论】:
【参考方案1】:你没有错过任何东西。我可以帮助你,因为我已经使用 PyTorch 处理了几个序列到序列的应用程序。下面我给你一个简单的例子。
class Seq2Seq(nn.Module):
"""A Seq2seq network trained on predicting the next query."""
def __init__(self, dictionary, embedding_index, args):
super(Seq2Seq, self).__init__()
self.config = args
self.num_directions = 2 if self.config.bidirection else 1
self.embedding = EmbeddingLayer(len(dictionary), self.config)
self.embedding.init_embedding_weights(dictionary, embedding_index, self.config.emsize)
self.encoder = Encoder(self.config.emsize, self.config.nhid_enc, self.config.bidirection, self.config)
self.decoder = Decoder(self.config.emsize, self.config.nhid_enc * self.num_directions, len(dictionary),
self.config)
@staticmethod
def compute_decoding_loss(logits, target, seq_idx, length):
losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1)).squeeze()
mask = helper.mask(length, seq_idx) # mask: batch x 1
losses = losses * mask.float()
num_non_zero_elem = torch.nonzero(mask.data).size()
if not num_non_zero_elem:
return losses.sum(), 0 if not num_non_zero_elem else losses.sum(), num_non_zero_elem[0]
def forward(self, q1_var, q1_len, q2_var, q2_len):
# encode the query
embedded_q1 = self.embedding(q1_var)
encoded_q1, hidden = self.encoder(embedded_q1, q1_len)
if self.config.bidirection:
if self.config.model == 'LSTM':
h_t, c_t = hidden[0][-2:], hidden[1][-2:]
decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2), torch.cat(
(c_t[0].unsqueeze(0), c_t[1].unsqueeze(0)), 2)
else:
h_t = hidden[0][-2:]
decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2)
else:
if self.config.model == 'LSTM':
decoder_hidden = hidden[0][-1], hidden[1][-1]
else:
decoder_hidden = hidden[-1]
decoding_loss, total_local_decoding_loss_element = 0, 0
for idx in range(q2_var.size(1) - 1):
input_variable = q2_var[:, idx]
embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
local_loss, num_local_loss = self.compute_decoding_loss(decoder_output, q2_var[:, idx + 1], idx, q2_len)
decoding_loss += local_loss
total_local_decoding_loss_element += num_local_loss
if total_local_decoding_loss_element > 0:
decoding_loss = decoding_loss / total_local_decoding_loss_element
return decoding_loss
你可以看到完整的源代码here。这个应用程序是关于在给定当前网络搜索查询的情况下预测用户的下一个网络搜索查询。
您的问题的回答者:
如何处理同一批次中不同长度序列的解码?
您有填充序列,因此您可以认为所有序列的长度相同。但是当您计算损失时,您需要使用 masking 忽略那些填充项的损失。
我在上面的例子中使用了masking 技术来实现同样的效果。
此外,您是绝对正确的:您需要逐个元素地解码小批量。初始解码器状态[batch_size, hidden_layer_dimension]
也很好。你只需要在维度 0 处解压它,使其成为[1, batch_size, hidden_layer_dimension]
。
请注意,您不需要循环遍历批处理中的每个示例,您可以一次执行整个批处理,但您需要循环遍历序列的元素。
【讨论】:
感谢您的回复。那么,在训练时,我让解码器从一批编码输入中预测一批输出,其中预测的序列最大长度是目标批次中最长的元素之一?希望它学会如何自己填充输出?但是在推理时,当我不再使用批次时,我可以使用相同的网络结构,只需使用批次维度 1 初始化隐藏权重吗?另外,你在解码器状态添加的维度是sequence_length维度,对吧?【参考方案2】:可能有点老,但我目前面临类似的问题:
我将数据批量输入到 seq2seq 模型中,但是当解码器从
decoder_input = torch.tensor([[self.SOS_token] * self.batch_size], device=device)
# decoder_input.size() = ([1,2])
for di in range(target_length):
decoder_output, decoder_hidden, decoder_attention = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
topv, topi = decoder_output.data.topk(1)
decoder_input = topi.squeeze().detach()
# decoder_input.size() = ([1])
outputs[di] = decoder_output
if decoder_input.item() == self.EOS_token:
break
所以如果例如我使用 batch_size = 2 我有 2 个 SOS_tokens,但后来我只得到 1 个单词作为预测,这是模型无法计算的,因为它需要不同的大小。我应该像 SOS_token 一样将它相乘吗? 最好的
【讨论】:
以上是关于批处理如何在 pytorch 的 seq2seq 模型中工作?的主要内容,如果未能解决你的问题,请参考以下文章
文本摘要Pytorch之Seq2seq: attention
PyTorch笔记 - Seq2Seq + Attention 算法