PyTorch神经网络层拆解
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch神经网络层拆解相关的知识,希望对你有一定的参考价值。
参考技术A本文将拆解常见的PyTorch神经网络层,从开发者的角度来看,这些神经网络层都是一个一个的函数,完成对数据的处理。
第一 :CLASS torch.nn.Flatten( start_dim=1 , end_dim=- 1 ) ,将多维的输入一维化,常用在从卷积层到全连接层的过渡。需要注意的是,Flatten()的默认值start_dim=1,即默认数据数据的格式是[N,C,H,W]第0维度为Batch Size,不参与Flatten。后面的CHW全部展平为一维。
第二 , CLASS torch.nn.Linear( in_features , out_features , bias=True , device=None , dtype=None ) ,Linear又叫全连接层,TensorFlow里面叫Dense,主要用于分类。
Linear类有两个属性:
第三 ,CLASS torch.nn.Conv2d (in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=\'zeros\', device=None, dtype=None),卷积层,常用于提取图像特征,CNN+RELU+MaxPooling已经成为一种常见的特征提取操作了。
需要注意的是:CNN要求数据输入格式为:[N, Cin, Hin, Wout],Cin是输入数据Tensor的通道数量,输出为[N, Cout, Hout, Wout],Cout为本CNN层的卷积个数。Hout和Wout计算公式如下所示:
范例程序:
总结:
以上是关于PyTorch神经网络层拆解的主要内容,如果未能解决你的问题,请参考以下文章