PyTorch学习系列——参数_初始化

Posted Vic时代

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习系列——参数_初始化相关的知识,希望对你有一定的参考价值。

上篇文章介绍了神经网络的参数定义,并简单提到了其默认初始化。这里详细聊聊怎么根据用户的需求对参数进行初始化。

PyTorch提供了多种参数初始化函数:

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。

示例:

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
init.xavier_uniform(self.conv1.weight)
init.constant(self.conv1.bias, 0.1)

上面的语句是对网络的某一层参数进行初始化。如何对整个网络的参数进行初始化定制呢?

def weights_init(m):
    classname=m.__class__.__name__
    if classname.find('Conv') != -1:
        xavier(m.weight.data)
        xavier(m.bias.data)
net = Net()
net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。   

不建议访问以下划线为前缀的成员,他们是内部的,如果有改变不会通知用户。更推荐的一种方法是检查某个module是否是某种类型:

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        xavier(m.bias.data)     

参考

[1] https://discuss.pytorch.org/t/weight-initilzation/157

以上是关于PyTorch学习系列——参数_初始化的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch参数初始化--默认与自定义

pytorch 参数初始化

pytorch对模型参数初始化

pytorch-卷积基本网络结构-提取网络参数-初始化网络参数

『PyTorch』第十三弹_torch.nn.init参数初始化

pytorch学习系列文章第二篇——张量