02-pytorch
Posted liu247
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了02-pytorch相关的知识,希望对你有一定的参考价值。
import torch
from pprint import pprint
from torch.autograd import Variable
变成图纸中的一个节点
tensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor,requires_grad=True)
pprint(tensor)
pprint(variable)
tensor([[1., 2.],
[3., 4.]])
tensor([[1., 2.],
[3., 4.]], requires_grad=True)
反向传播误差
t_out = torch.mean(tensor*tensor) # 求x^2的平均值
v_out = torch.mean(variable*variable)
# v_out = 1/4 *sum(var*var)
v_out.backward() # 反向传播误差
# d(v_out)/d(var) 1/4*2*variable = variable/2
print(variable.grad)
tensor([[0.5000, 1.0000],
[1.5000, 2.0000]])
# variable.data 才是 tensor 的形式
print(variable.data)
tensor([[1., 2.],
[3., 4.]])
以上是关于02-pytorch的主要内容,如果未能解决你的问题,请参考以下文章