Create custom gradient descent in pytorch


Posted: 2021-10-07 22:23:00

【Question】:

I am trying to use PyTorch autograd to implement my own batch gradient descent algorithm. I want to create a simple one-layer neural network with a linear activation function and mean squared error as the loss function. I can't quite work out what exactly happens in the backward pass and how PyTorch makes sense of my output. I wrote a class that specifies the linear function in the forward pass and, in the backward pass, computes the gradient with respect to each of its variables. I also wrote a class for the MSE function and specified the gradients with respect to ITS variables in the backward pass. When I run a simple gradient descent algorithm I get no errors, but the MSE only decreases during the first iterations and keeps rising afterwards. This leads me to believe I made a mistake somewhere, but I'm not sure where. Does anyone see an error in my code? Also, it would be great if someone could explain to me what exactly grad_output represents.

Here are the functions:

import torch
from torch.autograd import Function
from torch.autograd import gradcheck


class Context:
    def __init__(self):
        self._saved_tensors = ()

    def save_for_backward(self, *args):
        self._saved_tensors = args

    @property
    def saved_tensors(self):
        return self._saved_tensors

class MSE(Function):
    @staticmethod
    def forward(ctx, yhat, y):
        ctx.save_for_backward(yhat, y)
        q = yhat.size()[0]
        mse = torch.sum((yhat-y)**2)/q
        return mse

    @staticmethod
    def backward(ctx, grad_output):
        yhat, y = ctx.saved_tensors
        q = yhat.size()[0]
        return 2*grad_output*(yhat-y)/q, -2*grad_output*(yhat-y)/q  

class Linear(Function):
    @staticmethod
    def forward(ctx, X, W, b):
        rows = X.size()[0]
        yhat = torch.mm(X,W) + b.repeat(rows,1)
        ctx.save_for_backward(yhat, X, W)
        return yhat

    @staticmethod
    def backward(ctx, grad_output):
        yhat, X, W = ctx.saved_tensors
        q = yhat.size()[0]
        p = yhat.size()[1]
        return torch.transpose(X, 0, 1), W, torch.ones(p)

Here is my gradient descent:

import torch
from torch.utils.tensorboard import SummaryWriter       
from tp1moi import MSE, Linear, Context

 
x = torch.randn(50, 13)         
y = torch.randn(50, 3)

w = torch.randn(13, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

epsilon = 0.05

writer = SummaryWriter()
for n_iter in range(100):
    linear = Linear.apply
    mse = MSE.apply
    loss = mse(linear(x, w, b), y)

    writer.add_scalar('Loss/train', loss, n_iter)

    print(f"Itérations n_iter: loss loss")

    loss.backward()

    with torch.no_grad(): 
        w -= epsilon*w.grad
        b -= epsilon*b.grad
        w.grad.zero_()
        b.grad.zero_()

Here is one output I got (they all look similar to this one):

Itérations 0: loss 72.99712371826172
Itérations 1: loss 7.509067535400391
Itérations 2: loss 7.309497833251953
Itérations 3: loss 7.124927997589111
Itérations 4: loss 6.955358982086182
Itérations 5: loss 6.800788402557373
Itérations 6: loss 6.661219596862793
Itérations 7: loss 6.536648750305176
Itérations 8: loss 6.427078723907471
Itérations 9: loss 6.3325090408325195
Itérations 10: loss 6.252938747406006
Itérations 11: loss 6.188369274139404
Itérations 12: loss 6.138798713684082
Itérations 13: loss 6.104228973388672
Itérations 14: loss 6.084658145904541
Itérations 15: loss 6.0800886154174805
Itérations 16: loss 6.090517520904541
Itérations 17: loss 6.115947723388672
Itérations 18: loss 6.156377792358398
Itérations 19: loss 6.2118072509765625
Itérations 20: loss 6.2822370529174805
Itérations 21: loss 6.367666721343994
Itérations 22: loss 6.468096733093262
Itérations 23: loss 6.583526611328125
Itérations 24: loss 6.713956356048584
Itérations 25: loss 6.859385967254639
Itérations 26: loss 7.019815444946289
Itérations 27: loss 7.195245742797852
Itérations 28: loss 7.385674953460693
Itérations 29: loss 7.591104507446289
Itérations 30: loss 7.811534881591797
Itérations 31: loss 8.046965599060059
Itérations 32: loss 8.297393798828125
Itérations 33: loss 8.562823295593262
Itérations 34: loss 8.843254089355469
Itérations 35: loss 9.138683319091797
Itérations 36: loss 9.449112892150879
Itérations 37: loss 9.774543762207031
Itérations 38: loss 10.114972114562988
Itérations 39: loss 10.470401763916016
Itérations 40: loss 10.840831756591797
Itérations 41: loss 11.226261138916016
Itérations 42: loss 11.626690864562988
Itérations 43: loss 12.042119979858398
Itérations 44: loss 12.472548484802246
Itérations 45: loss 12.917980194091797
Itérations 46: loss 13.378408432006836
Itérations 47: loss 13.853838920593262
Itérations 48: loss 14.344267845153809
Itérations 49: loss 14.849695205688477
Itérations 50: loss 15.370124816894531
Itérations 51: loss 15.905555725097656
Itérations 52: loss 16.455984115600586
Itérations 53: loss 17.02141571044922
Itérations 54: loss 17.601844787597656
Itérations 55: loss 18.19727325439453
Itérations 56: loss 18.807701110839844
Itérations 57: loss 19.43313217163086
Itérations 58: loss 20.07356071472168
Itérations 59: loss 20.728988647460938
Itérations 60: loss 21.3994197845459
Itérations 61: loss 22.084848403930664
Itérations 62: loss 22.7852783203125
Itérations 63: loss 23.50070571899414
Itérations 64: loss 24.23113441467285
Itérations 65: loss 24.9765625
Itérations 66: loss 25.73699188232422
Itérations 67: loss 26.512422561645508
Itérations 68: loss 27.302854537963867
Itérations 69: loss 28.108285903930664
Itérations 70: loss 28.9287166595459
Itérations 71: loss 29.764144897460938
Itérations 72: loss 30.614578247070312
Itérations 73: loss 31.48000717163086
Itérations 74: loss 32.36043930053711
Itérations 75: loss 33.2558708190918
Itérations 76: loss 34.16630172729492
Itérations 77: loss 35.091732025146484
Itérations 78: loss 36.032161712646484
Itérations 79: loss 36.98759460449219
Itérations 80: loss 37.95802307128906
Itérations 81: loss 38.943458557128906
Itérations 82: loss 39.943885803222656
Itérations 83: loss 40.959320068359375
Itérations 84: loss 41.98974609375
Itérations 85: loss 43.03517532348633
Itérations 86: loss 44.09561538696289
Itérations 87: loss 45.171043395996094
Itérations 88: loss 46.261474609375
Itérations 89: loss 47.366905212402344
Itérations 90: loss 48.487335205078125
Itérations 91: loss 49.62276840209961
Itérations 92: loss 50.773197174072266
Itérations 93: loss 51.93863296508789
Itérations 94: loss 53.11906433105469
Itérations 95: loss 54.31448745727539
Itérations 96: loss 55.524925231933594
Itérations 97: loss 56.75035095214844
Itérations 98: loss 57.990787506103516
Itérations 99: loss 59.2462158203125

【Comments】:

【Answer 1】:

Let's look at the implementation of MSE. The forward pass is MSE(y, y_hat) = (y_hat-y)², which is straightforward. For the backward pass, we want to compute the derivative of the output with respect to the inputs, as well as with respect to each learned parameter. Here MSE has no learned parameters, so using the chain rule we only need dMSE/dy * dz/dMSE, i.e. d(y_hat-y)²/dy * dz/dMSE, which is -2(y_hat-y) * dz/dMSE. Don't let the notation confuse you: I write dz/dMSE for the incoming gradient, i.e. the gradient flowing backward towards the MSE layer. In your notation, grad_output is dz/dMSE. The backward pass is therefore just -2*(y_hat-y)*grad_output, normalized by the batch size q retrieved from y_hat.size(0).
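As a minimal sketch (assuming the same class layout as in the question), the MSE part could be written as below; note that this is essentially what the question's MSE.backward already does, so the bug is not in this class:

import torch
from torch.autograd import Function

class MSE(Function):
    @staticmethod
    def forward(ctx, yhat, y):
        ctx.save_for_backward(yhat, y)
        q = yhat.size(0)
        return torch.sum((yhat - y) ** 2) / q

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dz/dMSE, the gradient flowing back into this node
        # (a scalar tensor equal to 1.0 when backward() is called on the loss itself).
        yhat, y = ctx.saved_tensors
        q = yhat.size(0)
        grad_yhat = 2 * (yhat - y) / q * grad_output   # dMSE/dyhat * dz/dMSE
        grad_y = -2 * (yhat - y) / q * grad_output     # dMSE/dy * dz/dMSE
        return grad_yhat, grad_y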

The same goes for the Linear layer. It involves a bit more computation because, this time, the layer is parameterized by w and b. The forward pass is essentially x@w + b, while the backward pass consists of computing dz/dx, dz/dw, and dz/db. Writing f for x@w + b, after some work you will find:

dz/dx = d(x@w + b)/dx * dz/df = dz/df @ W.T
dz/dw = d(x@w + b)/dw * dz/df = X.T @ dz/df
dz/db = d(x@w + b)/db * dz/df = dz/df summed over the batch (since d(x@w + b)/db = 1)

In terms of implementation, this looks like:

output_grad@w.T for the gradient w.r.t. x,
x.T@output_grad for the gradient w.r.t. w, and
output_grad.sum(0) for the gradient w.r.t. b.
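Putting those three gradients together, a corrected Linear could look like the sketch below. The variable names are mine and this is only an illustration under the assumptions above; the key difference from the question's version is that the returned gradients actually depend on grad_output, and the bias gradient sums it over the batch dimension:

import torch
from torch.autograd import Function

class Linear(Function):
    @staticmethod
    def forward(ctx, X, W, b):
        ctx.save_for_backward(X, W)          # yhat is not needed for the backward pass
        return torch.mm(X, W) + b            # b broadcasts over the batch dimension

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dz/df, with the same shape as the forward output (q, p)
        X, W = ctx.saved_tensors
        grad_X = torch.mm(grad_output, W.t())   # dz/dx = dz/df @ W.T
        grad_W = torch.mm(X.t(), grad_output)   # dz/dw = X.T @ dz/df
        grad_b = grad_output.sum(dim=0)         # dz/db: dz/df summed over the batch
        return grad_X, grad_W, grad_b

An implementation like this can be sanity-checked with torch.autograd.gradcheck (already imported in the question), which compares the analytical backward against numerical gradients; it expects double-precision inputs that require gradients, e.g. gradcheck(Linear.apply, (x.double().requires_grad_(), w.double(), b.double())).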

【Discussion】:
