在Pytorch中实现WNGrad?
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了在Pytorch中实现WNGrad?相关的知识,希望对你有一定的参考价值。
我正在尝试在pytorch中实现WNGrad(技术上WN-Adam,论文中的算法4)optimizier(WNGrad)。我之前从未在pytorch中实现过优化器,因此我不知道我是否已正确完成(我从adam实现开始)。优化器没有取得太大进展并且像我预期的那样下降(bj值只能单调增加,这很快发生,所以没有进展)但我猜我有一个bug。标准优化器(Adam,SGD)在我试图优化的同一模型上工作正常。
这种实现看起来是否正确?
from torch.optim import Optimizer
class WNAdam(Optimizer):
"""Implements WNAdam algorithm.
It has been proposed in `WNGrad: Learn the Learning Rate in Gradient Descent`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 0.1)
beta1 (float, optional): exponential smoothing coefficient for gradient.
When beta=0 this implements WNGrad.
.. _WNGrad: Learn the Learning Rate in Gradient Descent:
https://arxiv.org/abs/1803.02865
"""
def __init__(self, params, lr=0.1, beta1=0.9):
if not 0.0 <= beta1 < 1.0:
raise ValueError("Invalid beta1 parameter: {}".format(beta1))
defaults = dict(lr=lr, beta1=beta1)
super().__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Learning rate adjustment
state['bj'] = 1.0
exp_avg = state['exp_avg']
beta1 = group['beta1']
state['step'] += 1
state['bj'] += (group['lr']**2)/(state['bj'])*grad.pow(2).sum()
# update exponential moving average
exp_avg.mul_(beta1).add_(1 - beta1, grad)
bias_correction = 1 - beta1 ** state['step']
p.data.sub_(group['lr'] / state['bj'] / bias_correction, exp_avg)
return loss
答案
WNGrad论文表明它受批量(和重量)标准化的启发。你应该使用关于权重维度的L2范数(不要总结)作为show in this algorithm
以上是关于在Pytorch中实现WNGrad?的主要内容,如果未能解决你的问题,请参考以下文章