[基于Pytorch的MNIST识别01]神经网络建立
Posted AIplusX
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[基于Pytorch的MNIST识别01]神经网络建立相关的知识,希望对你有一定的参考价值。
写在前面
前面我曾尝试在无框架的情况下进行神经网络的构建和调参,我发现虽然网络构建起来和运行起来都问题不大,但是在调参时就会显现无框架的弊端。经过初步的调参之后,我建立的网络识别准确率只能达到45%,但是我反复演算和查看公式、代码,都没有发现问题,因此调试就十分的困难了。
所以在经过思索和衡量之后,我决定还是采用一个轻量级的工具来帮助我调试神经网络。经过我的对比搜索之后,我决定使用pytorch,因为他兼容numpy且使用部署起来较为简便,便于我将精力投入到神经网络本身的学习,尽量减少环境对我的影响。
今天的工作
今天主要学习了pytorch的基本语法,并且将原来的BPnet.py部分文件进行了基于python的复现,主要是利用类来实现神经网络,这样可以使得代码可读性更高,先放出来看一下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
#print beta
torch.__version__
#hyperparameters
input_size = 784 #28*28
hidden_size = 16
batch_size = 100
out_put_size = 10 #0~9
#set up NeuralNet
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, out_put_size):
super(NeuralNet, self).__init__()
#recode hyperparameters
self.input_size = input_size
self.hidden_size = hidden_size
self.hidden_layer_size = hidden_layer_size
self.out_put_size = out_put_size
# 2 hidden_layers
self.gap0 = nn.Linear(input_size, hidden_size)
self.gap1 = nn.Linear(hidden_size, hidden_size)
self.gap2 = nn.Linear(hidden_size, out_put_size)
def forward(self, x):
out = self.gap0(x)
out = torch.relu(out)
out = self.gap1(out)
out = torch.rele(out)
out = self.gap2(out)
out = torch.sigmoid(out)
return out
# net = NeuralNet(input_size, hidden_size, out_put_size)
基本神经网络的结果还是如下图所示:
可以看到在借助了pytorch框架之后,我们就不需要拘泥于神经网络的层数和隐层内的神经元的数量了,基于pytorch框架的话可以很轻易,很清晰的进行修改,因此我就没有特别的将隐层内的层数和隐层内的神经元数量进行标注。
在这个初始的例子中,我还是沿用了之前的方法,2个隐层,每个隐层都是16个神经元,根据之后神经网络的表现可以灵活优化结构。
其中nn.Linear()
可以理解成神经网络的前层和后层之间的矩阵相乘的操作,而且这个函数会记录神经网络前向传递的过程,并且做为parameters参数保存下来。
我们可以用如下语句进行查看:
print(net)
for name,parameters in net.named_parameters():
print(name,':',parameters.size())
可以得到如下图所示的结果:
查看parameters参数之后可以发现神经网络的每层的参数都记录了下来,这样就可以跟踪神经网络的前向传递操作从而记录得到梯度,最终可以根据这个计算出来的梯度进行参数更新。
明天的工作
1、明天主要是进行MNIST数据的加载,并且将数据转换成Tensor;
2、将变量移动到GPU上从而加快模型训练速度;
以上是关于[基于Pytorch的MNIST识别01]神经网络建立的主要内容,如果未能解决你的问题,请参考以下文章
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!