如何入门Pytorch之二:如何搭建实用神经网络
Posted jimchen1218
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何入门Pytorch之二:如何搭建实用神经网络相关的知识,希望对你有一定的参考价值。
在上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等。
本节中,我们将介绍如何使用Pytorch来搭建一个实用的神经网络。
搭建一个神经网络并训练,大致有这么四个部分:
准备数据,搭建模型,评估函数,优化网络权重
1.数据准备
数据准备在上一篇中已讲过,这里就不多赘述了。
2.搭建模型
层(神经网络的基本组建单元)
针对y=wx+b,搭建简单线性模型
from torch.nn import Linear
inp = Variable(torch.randn(1,10))
myLayer = Linear(in_features=10,out_features=5,bias=True)
myLayer(inp)
myLayer.weight #通过权重来访问
Output :
Parameter containing:
-0.2386 0.0828 0.2904 0.3133 0.2037 0.1858 -0.2642 0.2862 0.2874 0.1141
0.0512 -0.2286 -0.1717 0.0554 0.1766 -0.0517 0.3112 0.0980 -0.2364 -0.0442
0.0776 -0.2169 0.0183 -0.0384 0.0606 0.2890 -0.0068 0.2344 0.2711 -0.3039
0.1055 0.0224 0.2044 0.0782 0.0790 0.2744 -0.1785 -0.1681 -0.0681 0.3141
0.2715 0.2606 -0.0362 0.0113 0.1299 -0.1112 -0.1652 0.2276 0.3082 -0.2745
[torch.FloatTensor of size 5x10]
myLayer.bias #通过偏置来访问
Output :
Parameter containing:
-0.2646
-0.2232
0.2444
0.2177
0.0897
[torch.FloatTensor of size 5
线性层在不同的框架里有着不同的名字,如:Dense或fully connected layers.
很多时候,为了解决现实世界问题,需要搭建多个层:
myLayer1 = Linear(10,5) myLayer2 = Linear(5,2) myLayer2(myLayer1(inp))
每层的参数都不一样:
Layers Weight1 Layer1 3.0 Layer2 2.0
不过,只是简单的堆叠线性层并不能有效帮助网络学习到更多新的东西。例如:
Y=2(3X1)-2LinearLayers
Y=6(X1)-1 LinearLayers
为了解决上面这个问题,就需要引入非线性函数。以下是一些常用的线性函数:
Sigmoid :f(x)=1 /(1+e^(-x))
Tanh:
ReLU:f(x)=max(0,x)
Leaky ReLU
sample_data = Variable(torch.Tensor([[1,2,-1,-1]])) myRelu = ReLU() myRelu(sample_data) Output: Variable containing: 1 2 0 0 [torch.FloatTensor of size 1x4]
搭建一个深度学习算法
基于nn.Module搭建的简易模型类
class MyFirstNetwork(nn.Module): def __init__(self,input_size,hidden_size,output_size): super(MyFirstNetwork,self).__init__() self.layer1 = nn.Linear(input_size,hidden_size) self.layer2 = nn.Linear(hidden_size,output_size) def __forward__(self,input): out = self.layer1(input) out = nn.ReLU(out) out = self.layer2(out) return out
机器学习主要解决的问题:分类,回归,多分类。
1、回归:使用线性网络的最近一层的其中一个输出值。
2、分类:使用sigmoid激活函数(近0或1)作为最终输出值。也即2分类问题。
3、多分类:使用softmax层作为最终输出。
以上是关于如何入门Pytorch之二:如何搭建实用神经网络的主要内容,如果未能解决你的问题,请参考以下文章