自定义参数初始化方法
Posted baitian963
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了自定义参数初始化方法相关的知识,希望对你有一定的参考价值。
def weight_init(m): if isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) # 也可以判断是否为conv2d,使用相应的初始化方式 elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode=‘fan_out‘, nonlinearity=‘relu‘) # 是否为批归一化层 elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # 2. 初始化网络结构 model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim) # 3. 将weight_init应用在子模块上 model.apply(weight_init)
自定义参数初始化方法
原博客:https://blog.csdn.net/dss_dssssd/article/details/83990511
def weight_init(m): if isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) # 也可以判断是否为conv2d,使用相应的初始化方式 elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode=‘fan_out‘, nonlinearity=‘relu‘) # 是否为批归一化层 elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)# 2. 初始化网络结构 model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)# 3. 将weight_init应用在子模块上model.apply(weight_init)————————————————版权声明:本文为CSDN博主「墨氲」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。原文链接:https://blog.csdn.net/dss_dssssd/article/details/83990511
以上是关于自定义参数初始化方法的主要内容,如果未能解决你的问题,请参考以下文章