关于对PyTorch中nn.Linear的官方API文档解读
Posted 小魏同学
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了关于对PyTorch中nn.Linear的官方API文档解读相关的知识,希望对你有一定的参考价值。
1.1 作用
pytorch中的全连接层,对于输入的向量x进行线性变换y=xAT+b
1.2 参数
in_features – 输入的向量的size
out_features – 最终返回的输出向量的size
bias – If set to False
, the layer will not learn an additive bias. Default: True
1.3 形状
-
Input: (N,∗,Hin) where * means any number of additional dimensions and Hin=in_features
-
Output: (N,∗,Hout) where all but the last dimension are the same shape as the input and Hout=out_features .
from torch import nn import torch linear=nn.Linear(in_features=64*3,out_features=1) a=torch.rand(3,7,64*3) print(a.shape) print(linear.weight.shape) b=linear(a) print(b.shape)
/Users/weihaoyang/opt/anaconda3/envs/nlp_chat/bin/python /Users/weihaoyang/PycharmProjects/pytorch练习/test.py torch.Size([3, 7, 192]) torch.Size([1, 192]) torch.Size([3, 7, 1]) Process finished with exit code 0
可以看出,输入是 3,7,192 输出是 3,7,1 ,其中 [1,192]就是这个全连接层的权重w的size, 经过转置之后变为[192,1] 。 [7,192] x [192,1] => [7,1]
以上是关于关于对PyTorch中nn.Linear的官方API文档解读的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 笔记:torch.nn.Linear() VS torch.nn.function.linear()
pytorch中的神经网络子模块(线性模块)——torch.nn.Linear
pytorch nn.Linear(对输入数据做线性变换:y=Ax+b)(全连接层?)
pytorch 是不是在 nn.Linear 中自动应用 softmax