Paddle 环境中 使用LeNet在MNIST数据集实现图像分类
Posted 卓晴
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Paddle 环境中 使用LeNet在MNIST数据集实现图像分类相关的知识,希望对你有一定的参考价值。
简 介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例。基于可以搭建其他网络程序。
关键词
: MNIST,Paddle,LeNet
- 作者: PaddlePaddle
- 日期: 2021.12
- 摘要: 本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。
§01 环境配置
本教程基于Paddle 2.2 编写,如果你的环境不是本版本,请先参考官网安装 Paddle 2.2。
import paddle
print(paddle.__version__)
2.2.1
§02 数据加载
手写数字的MNIST
数据集,包含60,000
个用于训练的示例和10,000
个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28
像素)
,其值为0
到1
。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist
。
一、加载mnist数据集合
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST
完成mnist
数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
download training data and load training data
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz
Begin to download
item 8/8 [============================>.] - ETA: 0s - 4ms/item
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz
Begin to download
Download finished
item 95/403 [======>.......................] - ETA: 0s - 2ms/item
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download
item 2/2 [===========================>..] - ETA: 0s - 2ms/item
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
Download finished
load finished
二、查看数据图像
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
train_data0 label is: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
'a.item() instead', DeprecationWarning, stacklevel=1)
▲ 图2.2.1 训练结合中的图片
§03 建立网络
用paddle.nn
下的API
,如Conv2D
、MaxPool2D
、Linear
完成LeNet
的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
§04 网络训练
一、方式1:基于高层API
通过paddle
提供的Model
构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
1、使用Model.fit完成模型训练
方式1:基于高层API,完成模型的训练与预测
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
▲ 图4.1.1 训练两个周期
运行时长:42.963
秒结束时间:2021-12-12 23:08:16
2、使用Model.evaluate进行模型预测
model.evaluate(test_dataset, batch_size=64, verbose=1)
▲ 图4.1.2 预测模型
'loss': [0.0013720806], 'acc': 0.9848
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
上述训练过程是在普通CPU的模式下进行的。
二、方式基于基础API
方式2:基于基础API,完成模型的训练与预测
1、模型训练
组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
def train(model):
model.train()
epochs = 使用LeNet网络进行MNIST图像分类
AI Studio : 利用Paddle框架中的极简框架识别MNIST
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!