Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()
Posted sbj123456789
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()相关的知识,希望对你有一定的参考价值。
torch.nn.utils.rnn.pack_padded_sequence()
这里的pack
,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)
其中pack的过程为:(注意pack的形式,不是按行压,而是按列压)
(下面方框内为PackedSequence
对象,由data和batch_sizes组成)
输入的形状可以是(T×B×* )。T
是最长序列长度,B
是batch size
,*
代表任意维度(可以是0)。如果batch_first=True
的话,那么相应的 input size
就是 (B×T×*)
。
Variable
中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]
代表的是最长的序列,input[:, B-1]
保存的是最短的序列。
NOTE:
只要是维度大于等于2的input
都可以作为这个函数的参数。你可以用它来打包labels
,然后用RNN
的输出和打包后的labels
来计算loss
。通过PackedSequence
对象的.data
属性可以获取 Variable
。
参数说明:
- input (Variable) – 变长序列 被填充后的 batch
- lengths (list[int]) –
Variable
中 每个序列的长度。 - batch_first (bool, optional) – 如果是
True
,input的形状应该是B*T*size
。
返回值:
一个PackedSequence
对象。
torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence
。
上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。
返回的Varaible的值的size
是 T×B×*
, T
是最长序列的长度,B
是 batch_size,如果 batch_first=True
,那么返回值是B×T×*
。
Batch中的元素将会以它们长度的逆序排列。
参数说明:
- sequence (PackedSequence) – 将要被填充的 batch
- batch_first (bool, optional) – 如果为True,返回的数据的格式为
B×T×*
。
返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表
一个例子:
输出:
此时PackedSequence对象输入RNN后,输出RNN的还是PackedSequence对象
参考:
https://www.cnblogs.com/lindaxin/p/8052043.html
https://pytorch.org/docs/stable/nn.html?highlight=pack_padded_sequence#torch.nn.utils.rnn.pack_padded_sequence
以上是关于Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()的主要内容,如果未能解决你的问题,请参考以下文章