torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册
Posted K同学啊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册相关的知识,希望对你有一定的参考价值。
函数原型
torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)
函数功能
此函数返回大小为 T x B x *
或 B x T x *
的张量,其中 T 是最长序列的长度。
参数详解
sequences (list[Tensor])
: 可变长度序列的列表,shape=[batch_size, N],N长度不一。batch_first (bool, optional)
:默认batch_size在第一维度padding_value (float, optional)
:填充的值,默认为0。
示例
from torch.nn.utils.rnn import pad_sequence
a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
pad_sequence([a, b, c]).size()
输出
torch.Size([25, 3, 300])
以上是关于torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册的主要内容,如果未能解决你的问题,请参考以下文章