pytorch 笔记:model.apply

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 笔记:model.apply相关的知识,希望对你有一定的参考价值。

network.apply(func)

——在每个子模组递归地执行func

——一般用于初始化参数中

@torch.no_grad()
def init_weights(m):
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        m.bias.fill_(0)
    if type(m)==nn.Conv2d:
        m.weight.fill_(4.9)
net = nn.Sequential(nn.Linear(2, 2), nn.Conv2d(2,2,1))
net.apply(init_weights)
for i in net.parameters():
    print(i)

'''
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Parameter containing:
tensor([0., 0.], requires_grad=True)
Parameter containing:
tensor([[[[4.9000]],

         [[4.9000]]],


        [[[4.9000]],

         [[4.9000]]]], requires_grad=True)
Parameter containing:
tensor([0.3902, 0.0678], requires_grad=True)
'''

以上是关于pytorch 笔记:model.apply的主要内容,如果未能解决你的问题,请参考以下文章