『PyTorch』第六弹_最小二乘法的不同实现手段(待续)

Posted 叠加态的猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了『PyTorch』第六弹_最小二乘法的不同实现手段(待续)相关的知识,希望对你有一定的参考价值。

PyTorch的Variable

import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
from IPython import display

# 指定随机数种子
t.manual_seed(1000)

def get_fake_data(batch_size=8):
    x = t.rand(batch_size,1)*20
    y = x * 2 + 3 + 3*t.randn(batch_size,1)
    return x, y

x, y = get_fake_data()
plt.scatter(x.squeeze(), y.squeeze())

w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)

lr = 0.001

for ii in range(8000):
    x, y = get_fake_data()
    x, y = V(x), V(y)
    # print(x, y)
    y_pred = x.mm(w) + b.expand_as(x)

    loss = 0.5*(y_pred - y)**2
    loss = loss.sum()  # 集结loss向量

    loss.backward()

    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)

    w.grad.data.zero_()
    b.grad.data.zero_()

    if ii % 1000 == 0:
        display.clear_output(wait=True)
        x = t.arange(0,20).view(-1,1)
        y = x.mm(w.data) + b.data.expand_as(x)
        plt.plot(x.numpy(), y.numpy())
        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2, y2)

        plt.xlim(0,20)
        plt.ylim(0,40)
        plt.show()
        
print(w.data.squeeze(), b.data.squeeze())

 

以上是关于『PyTorch』第六弹_最小二乘法的不同实现手段(待续)的主要内容,如果未能解决你的问题,请参考以下文章

机器学习-最小二乘法

Mybatis 踩坑第六弹—缓存

机器学习之用Python实现最小二乘法预测房价,进行额度预测

01_有监督学习--简单线性回归模型(最小二乘法代码实现)

Java常见面试题(第六弹):分布式锁的实现方式有哪三种?

css学习の第六弹—样式设置小技巧