torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册

Posted K同学啊

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册相关的知识,希望对你有一定的参考价值。

函数原型

torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

函数功能

此函数返回大小为 T x B x *B x T x * 的张量,其中 T 是最长序列的长度。

参数详解

  • sequences (list[Tensor]): 可变长度序列的列表,shape=[batch_size, N],N长度不一。
  • batch_first (bool, optional) :默认batch_size在第一维度
  • padding_value (float, optional) :填充的值,默认为0。

示例

from torch.nn.utils.rnn import pad_sequence

a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
pad_sequence([a, b, c]).size()

输出

torch.Size([25, 3, 300])

以上是关于torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册的主要内容,如果未能解决你的问题,请参考以下文章