python pytorch RNN变长输入填充
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python pytorch RNN变长输入填充相关的知识,希望对你有一定的参考价值。
# x内的seq需要按长度降序排列
x_emb = self.emb(x)
x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x_emb, xlen, batch_first=True)
out_pack, (ht, ct) = self.rnn(x_1_emb_p, None)
'''
返回的ht和ct就是剔除padding字符后的hidden state和cell state,都是Variable类型的。
但是返回的output是PackedSequence类型的
'''
out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=True)
#将output转化成tuple:(padded sequence, sequence lengths)。
以上是关于python pytorch RNN变长输入填充的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch 中如何处理 RNN 输入变长序列 padding
深度学习实战pytorch中如何处理RNN输入变长序列padding
Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()
在 pytorch 中通过具有线性输出层的 RNN 发送的填充批次的掩码和计算损失
PyTorch笔记 - Recurrent Neural Network(RNN) 循环神经网络
PyTorch笔记 - Recurrent Neural Network(RNN) 循环神经网络