Create custom gradient descent in PyTorch
Posted: 2021-10-07 22:23:00

【Question】: I am trying to implement my own batch gradient descent algorithm using PyTorch autograd. I want to build a simple single-layer neural network with a linear activation function and mean squared error as the loss function. I can't quite work out what exactly happens in the backward pass and how PyTorch interprets my outputs. I wrote a class that specifies the linear function in its forward pass and computes the gradient with respect to each of its variables in its backward pass. I also wrote a class for the MSE function and specified the gradient with respect to ITS variables in its backward pass. When I run a simple gradient descent loop I don't get any errors, but the MSE only decreases during the first few iterations and keeps rising afterwards. This makes me believe I made a mistake somewhere, but I'm not sure where. Does anyone see an error in my code? Also, it would be great if someone could explain to me what exactly grad_output represents.
Here are the functions:
```python
import torch
from torch.autograd import Function
from torch.autograd import gradcheck


class Context:
    def __init__(self):
        self._saved_tensors = ()

    def save_for_backward(self, *args):
        self._saved_tensors = args

    @property
    def saved_tensors(self):
        return self._saved_tensors


class MSE(Function):
    @staticmethod
    def forward(ctx, yhat, y):
        ctx.save_for_backward(yhat, y)
        q = yhat.size()[0]
        mse = torch.sum((yhat - y)**2) / q
        return mse

    @staticmethod
    def backward(ctx, grad_output):
        yhat, y = ctx.saved_tensors
        q = yhat.size()[0]
        return 2*grad_output*(yhat - y)/q, -2*grad_output*(yhat - y)/q


class Linear(Function):
    @staticmethod
    def forward(ctx, X, W, b):
        rows = X.size()[0]
        yhat = torch.mm(X, W) + b.repeat(rows, 1)
        ctx.save_for_backward(yhat, X, W)
        return yhat

    @staticmethod
    def backward(ctx, grad_output):
        yhat, X, W = ctx.saved_tensors
        q = yhat.size()[0]
        p = yhat.size()[1]
        return torch.transpose(X, 0, 1), W, torch.ones(p)
```
Here is my gradient descent:
```python
import torch
from torch.utils.tensorboard import SummaryWriter
from tp1moi import MSE, Linear, Context

x = torch.randn(50, 13)
y = torch.randn(50, 3)
w = torch.randn(13, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
epsilon = 0.05

writer = SummaryWriter()
for n_iter in range(100):
    linear = Linear.apply
    mse = MSE.apply

    loss = mse(linear(x, w, b), y)
    writer.add_scalar('Loss/train', loss, n_iter)
    print(f"Itérations {n_iter}: loss {loss}")

    loss.backward()
    with torch.no_grad():
        w -= epsilon * w.grad
        b -= epsilon * b.grad
        w.grad.zero_()
        b.grad.zero_()
```
Here is one of the outputs I get (they all look similar to this one):

```
Itérations 0: loss 72.99712371826172
Itérations 1: loss 7.509067535400391
Itérations 2: loss 7.309497833251953
Itérations 3: loss 7.124927997589111
Itérations 4: loss 6.955358982086182
Itérations 5: loss 6.800788402557373
Itérations 6: loss 6.661219596862793
Itérations 7: loss 6.536648750305176
Itérations 8: loss 6.427078723907471
Itérations 9: loss 6.3325090408325195
Itérations 10: loss 6.252938747406006
Itérations 11: loss 6.188369274139404
Itérations 12: loss 6.138798713684082
Itérations 13: loss 6.104228973388672
Itérations 14: loss 6.084658145904541
Itérations 15: loss 6.0800886154174805
Itérations 16: loss 6.090517520904541
Itérations 17: loss 6.115947723388672
Itérations 18: loss 6.156377792358398
Itérations 19: loss 6.2118072509765625
Itérations 20: loss 6.2822370529174805
Itérations 21: loss 6.367666721343994
Itérations 22: loss 6.468096733093262
Itérations 23: loss 6.583526611328125
Itérations 24: loss 6.713956356048584
Itérations 25: loss 6.859385967254639
Itérations 26: loss 7.019815444946289
Itérations 27: loss 7.195245742797852
Itérations 28: loss 7.385674953460693
Itérations 29: loss 7.591104507446289
Itérations 30: loss 7.811534881591797
Itérations 31: loss 8.046965599060059
Itérations 32: loss 8.297393798828125
Itérations 33: loss 8.562823295593262
Itérations 34: loss 8.843254089355469
Itérations 35: loss 9.138683319091797
Itérations 36: loss 9.449112892150879
Itérations 37: loss 9.774543762207031
Itérations 38: loss 10.114972114562988
Itérations 39: loss 10.470401763916016
Itérations 40: loss 10.840831756591797
Itérations 41: loss 11.226261138916016
Itérations 42: loss 11.626690864562988
Itérations 43: loss 12.042119979858398
Itérations 44: loss 12.472548484802246
Itérations 45: loss 12.917980194091797
Itérations 46: loss 13.378408432006836
Itérations 47: loss 13.853838920593262
Itérations 48: loss 14.344267845153809
Itérations 49: loss 14.849695205688477
Itérations 50: loss 15.370124816894531
Itérations 51: loss 15.905555725097656
Itérations 52: loss 16.455984115600586
Itérations 53: loss 17.02141571044922
Itérations 54: loss 17.601844787597656
Itérations 55: loss 18.19727325439453
Itérations 56: loss 18.807701110839844
Itérations 57: loss 19.43313217163086
Itérations 58: loss 20.07356071472168
Itérations 59: loss 20.728988647460938
Itérations 60: loss 21.3994197845459
Itérations 61: loss 22.084848403930664
Itérations 62: loss 22.7852783203125
Itérations 63: loss 23.50070571899414
Itérations 64: loss 24.23113441467285
Itérations 65: loss 24.9765625
Itérations 66: loss 25.73699188232422
Itérations 67: loss 26.512422561645508
Itérations 68: loss 27.302854537963867
Itérations 69: loss 28.108285903930664
Itérations 70: loss 28.9287166595459
Itérations 71: loss 29.764144897460938
Itérations 72: loss 30.614578247070312
Itérations 73: loss 31.48000717163086
Itérations 74: loss 32.36043930053711
Itérations 75: loss 33.2558708190918
Itérations 76: loss 34.16630172729492
Itérations 77: loss 35.091732025146484
Itérations 78: loss 36.032161712646484
Itérations 79: loss 36.98759460449219
Itérations 80: loss 37.95802307128906
Itérations 81: loss 38.943458557128906
Itérations 82: loss 39.943885803222656
Itérations 83: loss 40.959320068359375
Itérations 84: loss 41.98974609375
Itérations 85: loss 43.03517532348633
Itérations 86: loss 44.09561538696289
Itérations 87: loss 45.171043395996094
Itérations 88: loss 46.261474609375
Itérations 89: loss 47.366905212402344
Itérations 90: loss 48.487335205078125
Itérations 91: loss 49.62276840209961
Itérations 92: loss 50.773197174072266
Itérations 93: loss 51.93863296508789
Itérations 94: loss 53.11906433105469
Itérations 95: loss 54.31448745727539
Itérations 96: loss 55.524925231933594
Itérations 97: loss 56.75035095214844
Itérations 98: loss 57.990787506103516
Itérations 99: loss 59.2462158203125
```
【Answer 1】: Let's look at the implementation of `MSE`. The forward pass is `MSE(y, y_hat) = (y_hat - y)²`, which is straightforward. For the backward pass, we want to compute the derivative of the output with respect to the inputs and with respect to each parameter. Here `MSE` has no learnable parameters, so using the chain rule we only need to compute `dMSE/dy * dz/dMSE`, i.e. `d(y_hat - y)²/dy * dz/dMSE`, i.e. `-2(y_hat - y) * dz/dMSE`. Don't let the notation confuse you: I write `dz/dMSE` for the incoming gradient, i.e. the gradient flowing backwards into the MSE layer. In your notation, `grad_output` is `dz/dMSE`. So the backward pass is simply `-2*(y_hat - y)*grad_output`, normalized by the batch size `q` retrieved from `y_hat.size(0)`.
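As a minimal sketch, this is what that backward pass looks like written out with the same `Function`/`ctx` interface used in the question (the gradient with respect to `y` is just the gradient with respect to `y_hat` with the sign flipped, and it mirrors what the question's `MSE` already does):

```python
import torch
from torch.autograd import Function

class MSE(Function):
    @staticmethod
    def forward(ctx, yhat, y):
        ctx.save_for_backward(yhat, y)
        q = yhat.size(0)
        return torch.sum((yhat - y) ** 2) / q

    @staticmethod
    def backward(ctx, grad_output):
        yhat, y = ctx.saved_tensors
        q = yhat.size(0)
        # chain rule: dz/dyhat = grad_output * 2 * (yhat - y) / q
        grad_yhat = 2 * grad_output * (yhat - y) / q
        # dz/dy is the same quantity with the opposite sign
        return grad_yhat, -grad_yhat
```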
The same goes for the `Linear` layer. It involves a bit more computation, because this time the layer is parameterized by `w` and `b`. The forward pass is essentially `x@w + b`. The backward pass consists of computing `dz/dx`, `dz/dw`, and `dz/db`. Writing `f` for `x@w + b`, after a bit of work you will find:

`dz/dx = d(x@w + b)/dx * dz/df = dz/df @ w.T`,
`dz/dw = d(x@w + b)/dw * dz/df = x.T @ dz/df`,
`dz/db = d(x@w + b)/db * dz/df = 1`.

In terms of implementation, this looks like:

`grad_output @ w.T` for the gradient w.r.t. `x`,
`x.T @ grad_output` for the gradient w.r.t. `w`,
`torch.ones_like(b)` for the gradient w.r.t. `b`.
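Putting the formulas above into the same custom `Function` style as the question, a corrected `Linear.backward` could look roughly like the sketch below. This is an illustrative sketch, not a tested drop-in replacement; note that for `b` it returns `grad_output` summed over the batch dimension (because `b` is broadcast across the rows of `x`), whereas `torch.ones_like(b)` would discard the incoming gradient.

```python
import torch
from torch.autograd import Function

class Linear(Function):
    @staticmethod
    def forward(ctx, X, W, b):
        ctx.save_for_backward(X, W)
        # b broadcasts over the batch dimension, no need for repeat()
        return torch.mm(X, W) + b

    @staticmethod
    def backward(ctx, grad_output):
        X, W = ctx.saved_tensors
        grad_X = grad_output @ W.t()     # dz/dx = dz/df @ w.T
        grad_W = X.t() @ grad_output     # dz/dw = x.T @ dz/df
        grad_b = grad_output.sum(dim=0)  # b is broadcast, so sum dz/df over the rows
        return grad_X, grad_W, grad_b
```

Since the question already imports `gradcheck`, running it on `MSE.apply` and `Linear.apply` with small double-precision inputs is a quick way to catch this kind of mistake. With gradients like these, the training loop from the question should see the loss keep decreasing for a small enough `epsilon`.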