了解 PyTorch LSTM 的输入形状

Posted

技术标签:

【中文标题】了解 PyTorch LSTM 的输入形状【英文标题】:Understanding input shape to PyTorch LSTM 【发布时间】:2020-08-21 06:16:26 【问题描述】:

这似乎是 PyTorch 中关于 LSTM 的最常见问题之一,但我仍然无法弄清楚 PyTorch LSTM 的输入形状应该是什么。

即使在关注了几个帖子(1、2、3)并尝试了解决方案之后,它似乎也不起作用。

背景:我已经编码了一批大小为 12 的文本序列(可变长度),并使用 pad_packed_sequence 功能填充和打包了这些序列。每个序列的MAX_LEN 为 384,序列中的每个标记(或单词)的维度为 768。因此,我的批量张量可能具有以下形状之一:[12, 384, 768][384, 12, 768]

批处理将是我对 PyTorch rnn 模块(此处为 lstm)的输入。

根据 LSTMs 的 PyTorch 文档,它的输入维度是 (seq_len, batch, input_size),我理解如下。seq_len - 每个输入流中的时间步数(特征向量长度)。batch - 每批输入序列的大小。input_size - 每个输入标记或时间步长的维度。

lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)

input_sizehidden_size 的确切值应该是什么?

【问题讨论】:

【参考方案1】:

图像传到CNN层和lstm层,特征图形状变化是这样的

BCHW->BCHW(BxCx1xW), CNN 的输出形状的高度应为 1。 然后挤压高度的暗淡。 BCHW->BCW 在 rnn 中,形状名称更改,[batch ,seqlen,input_size],在图像中,[batch,width,channel], **BCW->BWC,**这是 LSTM 层的 batch_first 张量(如 pytorch)。 最后: BWC 是 [batch,seqlen,channel]。

【讨论】:

【参考方案2】:

您已经解释了输入的结构,但您尚未在输入维度和 LSTM 的预期输入维度之间建立联系。

让我们分解您的输入(为维度分配名称):

batch_size:12 seq_len: 384 input_size/num_features:768

这意味着 LSTM 的input_size 需要为 768。

hidden_size 不取决于您的输入,而是取决于 LSTM 应该创建多少特征,然后将其用于隐藏状态和输出,因为这是最后一个隐藏状态。您必须决定要为 LSTM 使用多少功能。

最后,对于输入形状,设置batch_first=True 要求输入的形状为[batch_size, seq_len, input_size],在您的情况下为[12, 384, 768]

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(input)
output.size()  # => torch.Size([12, 384, 512])

【讨论】:

官方文档解释“输入形状(seq_len,batch,input_size)”。所以,我认为输入形状应该是 (384, 12, 768)。 pytorch.org/docs/stable/generated/torch.nn.LSTM.html 对不起,我意识到batch_first 选项。此参数替换 seq_len 和 batch。所以,你是对的。

以上是关于了解 PyTorch LSTM 的输入形状的主要内容,如果未能解决你的问题,请参考以下文章

了解 TensorFlow LSTM 输入形状

PyTorch 中的双向 LSTM 输出问题

理解 LSTM 中的输入和输出形状 | tf.keras.layers.LSTM(以及对于return_sequences的解释)

如何正确地为 PyTorch 中的嵌入、LSTM 和线性层提供输入?

了解 LSTM 的输出形状

在 Pytorch 中实现有状态 LSTM/ConvLSTM 的最佳方式是啥?