nn.linear()

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了nn.linear()相关的知识,希望对你有一定的参考价值。

参考技术A import torch
import torch.nn

nn.linear()是用来设置网络中的全连接层的,而在全连接层中的输入与输出都是二维张量,一般形状为[batch_size, size],与卷积层要求输入输出是4维张量不同。
用法与形参见说明如下:

in_features指的是输入的二维张量的大小,即输入的[batch_size, size]中的size。
batch_size指的是每次训练(batch)的时候样本的大小。比如CNN train的样张图片是60张,设置batch_size=15,那么iteration=4。如果想多训练几次(因为可以每次的batch不是相同的数据),那么就是epoch。
所以nn.Linear()中的输入包括有输入的图片数量,同时还有每张图片的维度。
out_features指的是输出的二维张量的大小,即输出[batch_size,size]中的size是输出的张量维度,而batch_size与输入中的一致。

参考: PyTorch的nn.Linear()详解

以上是关于nn.linear()的主要内容,如果未能解决你的问题,请参考以下文章

nn.linear()

nn.Linear 默认参数初始化方法

pytorch nn.Linear()详解

torch.nn.Linear() 理解

torch.nn.Linear() 理解

torch.nn.Linear() 理解