Pytorch的参数“batch_first”的理解

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch的参数“batch_first”的理解相关的知识,希望对你有一定的参考价值。

参考技术A

用过PyTorch的朋友大概都知道,对于不同的网络层,输入的维度虽然不同,但是通常输入的第一个维度都是batch_size,比如torch.nn.Linear的输入(batch_size,in_features),torch.nn.Conv2d的输入(batch_size, C, H, W)。而RNN的输入却是(seq_len, batch_size, input_size),batch_size位于第二维度!虽然你可以将batch_size和序列长度seq_len对换位置,此时只需要令batch_first=True。
但是 为什么RNN输入默认不是batch first=True?这是为了便于并行计算 。因为cuDNN中RNN的API就是batch_size在第二维度!进一步,为啥cuDNN要这么做呢?因为batch first意味着模型的输入(一个Tensor)在内存中存储时,先存储第一个sequence,再存储第二个... 而如果是seq_len first,模型的输入在内存中,先存储所有序列的第一个单元,然后是第二个单元... 两种区别如下图所示:

[参考资料] https://zhuanlan.zhihu.com/p/32103001

以上是关于Pytorch的参数“batch_first”的理解的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中LSTM的输出的理解,以及batch_first=True or False的输出层的区别

pytorch1.0 搭建LSTM网络

torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册

torch.nn.utils.rnn.pad_sequence()详解Pytorch入门手册

将 GRU 层从 PyTorch 转换为 TensorFlow

解码器 LSTM Pytorch 的图像字幕示例输入大小