Pytorch学习率相关问题及L2 Penalty

Posted PyTorch与深度学习

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch学习率相关问题及L2 Penalty相关的知识,希望对你有一定的参考价值。

1.Pytorch对于不同层设置不同的学习率

 
   
   
 
  1. optim.SGD([

  2.                {'params': model.base.parameters()},

  3.                {'params': model.classifier.parameters(), 'lr': 1e-3}

  4.            ], lr=1e-2, momentum=0.9)

  • model.base.parameters()将使用1e-2的学习率,

  • model.classifier.parameters()将使用1e-3的学习率,

  • 0.9的momentum作用于所有的parameters。

2.甚至可以对于同一层的weight和bias设置不同的lr

 
   
   
 
  1. import torch

  2. import torch.nn as nn

  3. import torch.optim as optim

  4. from torch.autograd import Variable

  5. class Net(nn.Module):

  6.    def __init__(self):

  7.        super(Net, self).__init__()

  8.        self.layer = nn.Linear(1, 1)

  9.    def forward(self, x):

  10.        return self.layer(x)

  11. if __name__=="__main__":

  12.    net = Net()

  13.    optimizer = optim.Adam([

  14.                {'params': net.layer.weight},

  15.                {'params': net.layer.bias, 'lr': 0.0001}

  16.            ], lr=0.01, weight_decay=0.0001)

  17.    out = net(Variable(torch.Tensor(1)))

  18.    out.backward()

  19.    optimizer.step()

3.训练过程中学习率的调节

torch.optim.lr_scheduler中提供了多种方法,如

  • LambdaLR

  • StepLR

  • MultiStepLR

  • ExponentialLR 

  • 举个官网给的例子

 
   
   
 
  1. >>> # Assuming optimizer uses lr = 0.05 for all groups

  2. >>> # lr = 0.05     if epoch < 30

  3. >>> # lr = 0.005    if 30 <= epoch < 80

  4. >>> # lr = 0.0005   if epoch >= 80

  5. >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

  6. >>> for epoch in range(100):

  7. >>>     scheduler.step()

  8. >>>     train(...)

  9. >>>     validate(...)

4.对参数加 L2 penalty

Pytorch中的Adam,RMSprop,SGD等都有一个weight_decay的参数,默认为0,如果不为0,对参数会施加L2 penalty。

optimizer=torch.optim.Adam(model.parameters(),lr=1e-4,weight_decay=1e-5)


以上是关于Pytorch学习率相关问题及L2 Penalty的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch 模型 查看网络参数的梯度以及参数更新是否正确,优化器学习率的分层设置

如何在 Pytorch 中应用分层学习率?

pytorch常用优化器总结(包括warmup介绍及代码实现)

pytorch常用优化器总结(包括warmup介绍及代码实现)

pytorch常用优化器总结(包括warmup介绍及代码实现)

Pytorch Note34 学习率衰减