在pytorch神经网络中初始化权重
Posted
技术标签:
【中文标题】在pytorch神经网络中初始化权重【英文标题】:Initialize weight in pytorch neural net 【发布时间】:2021-06-25 00:47:43 【问题描述】:我已经创建了这个神经网络:
class _netD(nn.Module):
def __init__(self, num_classes=1, nc=1, ndf=64):
super(_netD, self).__init__()
self.num_classes = num_classes
# nc is number of channels
# num_classes is number of classes
# ndf is the number of output channel at the first layer
self.main = nn.Sequential(
# input is (nc) x 28 x 28
# conv2D(in_channels, out_channels, kernelsize, stride, padding)
nn.Conv2d(nc, ndf , 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 14 x 14
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 7 x 7
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 3 x 3
nn.Conv2d(ndf * 4, ndf * 8, 3, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 2 x 2
nn.Conv2d(ndf * 8, num_classes, 2, 1, 0, bias=False),
# out size = batch x num_classes x 1 x 1
)
if self.num_classes == 1:
self.main.add_module('prob', nn.Sigmoid())
# output = probability
else:
pass
# output = scores
def forward(self, input):
output = self.main(input)
return output.view(input.size(0), self.num_classes).squeeze(1)
我想遍历不同的层并根据层的类型应用权重初始化。我正在尝试执行以下操作:
D = _netD()
for name, param in D.named_parameters():
if type(param) == nn.Conv2d:
param.weight.normal_(...)
但这不起作用。你能帮帮我吗?
谢谢
【问题讨论】:
【参考方案1】:type(param)
只会为模型中的任何类型的权重或数据返回称为parameter
的实际数据类型。因为named_parameters()
在基于nn.sequential
的模型上使用时也不会返回任何有用的名称,因此您需要查看模块以查看哪些层与使用isinstance
的nn.Conv2d 类特别相关比如:
for layer in D.modules():
if isinstance(layer, nn.Conv2d):
layer.weight.data.normal_(...)
或者,Soumith Chintala 本人推荐的方式,实际上只是循环通过您的主模块本身:
for L,layer in D.main:
if isisntance(layer,nn.Conv2d):
layer.weight.data.normal_(..)
我实际上更喜欢第一个,因为您不必指定确切的 nn.sequential 模块本身,并且会搜索模型中所有可能的模块,但任何一个都应该为您完成这项工作。
【讨论】:
以上是关于在pytorch神经网络中初始化权重的主要内容,如果未能解决你的问题,请参考以下文章