关于对PyTorch中nn.Linear的官方API文档解读

Posted 小魏同学

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了关于对PyTorch中nn.Linear的官方API文档解读相关的知识,希望对你有一定的参考价值。

torch.nn.Linear(in_features, out_features, bias=True)

1.1 作用

 

pytorch中的全连接层,对于输入的向量x进行线性变换y=xAT+b

 

1.2 参数

 

in_features – 输入的向量的size

 

out_features – 最终返回的输出向量的size

 

bias – If set to False, the layer will not learn an additive bias. Default: True

 

1.3 形状

 

  • Input: (N,∗,Hin) where * means any number of additional dimensions and Hin=in_features

  • Output: (N,∗,Hout) where all but the last dimension are the same shape as the input and Hout=out_features .

 

意思就是输入输出向量不一定非要是两维,保证第一维相同,最后一个dim能和in_features,out_features匹配就行了

 

from torch import nn
import torch

linear=nn.Linear(in_features=64*3,out_features=1)

a=torch.rand(3,7,64*3)

print(a.shape)
print(linear.weight.shape)
b=linear(a)
print(b.shape)

 

/Users/weihaoyang/opt/anaconda3/envs/nlp_chat/bin/python /Users/weihaoyang/PycharmProjects/pytorch练习/test.py
torch.Size([3, 7, 192])
torch.Size([1, 192])
torch.Size([3, 7, 1])

Process finished with exit code 0

 

可以看出,输入是 3,7,192 输出是 3,7,1 ,其中 [1,192]就是这个全连接层的权重w的size, 经过转置之后变为[192,1] 。 [7,192] x [192,1] => [7,1]




以上是关于关于对PyTorch中nn.Linear的官方API文档解读的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 笔记:torch.nn.Linear() VS torch.nn.function.linear()

pytorch中的神经网络子模块(线性模块)——torch.nn.Linear

pytorch nn.Linear(对输入数据做线性变换:y=Ax+b)(全连接层?)

pytorch 是不是在 nn.Linear 中自动应用 softmax

如何在 Pytorch 的“nn.Sequential”中展平输入

pytorch之求梯度和nn.Linear的理解