pytorch lightning 手写数字分类实例
Posted Tina姐
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch lightning 手写数字分类实例 相关的知识,希望对你有一定的参考价值。
今天通过手写数字来学习如何利用pytorch-lightning进行分类
代码同第二部分的差不多,新增了断点训练和测试部分。
项目使用jupyter notebook演示
此部分代码很简单,小白也能上手,赶快来试一试吧~~~
该系列还有
pytorch-lightning入门(一)—— 初了解
如何从Pytorch 到 Pytorch Lightning (二) | 简要介绍
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
from pytorch_lightning import seed_everything
import numpy as np
import matplotlib.pyplot as plt
SET SEED
# 首先设置随机数种子
seed_everything(seed=42)
# 定义模型
class LightningMNISTClassifier(pl.LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
# layer 1 (b, 1*28*28) -> (b, 128)
x = self.layer_1(x)
x = torch.relu(x)
# layer 2 (b, 128) -> (b, 256)
x = self.layer_2(x)
x = torch.relu(x)
# layer 3 (b, 256) -> (b, 10)
x = self.layer_3(x)
# probability distribution over labels
x = torch.log_softmax(x, dim=1)
return x
def cross_entropy_loss(self, logits, labels):
return F.nll_loss(logits, labels)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('train_loss', loss)
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
self.log('val_loss', loss)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
dataloader
# data
# transforms for images
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# prepare transforms standard to MNIST
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=64)
val_loader = DataLoader(mnist_test, batch_size=64)
没有下载的会自动进行下载,如果速度慢,就手动下载保存。有VPN 可以直接运行代码
training
接下来开始训练,提供两种训练方法
- 从头训练
- 从断点开始训练
# train
model = LightningMNISTClassifier()
# resume training
RESUME = False
if RESUME:
resume_checkpoint_dir = './lightning_logs/version_1/checkpoints/'
checkpoint_path = os.listdir(resume_checkpoint_dir)[0]
resume_checkpoint_path = resume_checkpoint_dir + checkpoint_path
trainer = pl.Trainer(gpus='1',
max_epochs=10,
resume_from_checkpoint=resume_checkpoint_path)
trainer.fit(model, train_dataloader, val_loader)
else:
trainer = pl.Trainer(gpus='1', max_epochs=20)
trainer.fit(model, train_dataloader, val_loader)
输出包括:
训练结果默认保存在文件夹: ./lightning_logs.
会根据你运行的次数自动命名这是版本x。
运行期间可以打开tensorboard 查看运行情况
testing
# test
checkpoint_dir = 'lightning_logs/version_2/checkpoints/'
checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]
model = LightningMNISTClassifier()
model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
inputs, labels = next(iter(val_loader))
# inference
outputs = model(inputs)
这里,我只测试一个batch的数据。
测试结果显示
def imshow(inp):
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
plt.show()
import torchvision
# 求outputs最大索引
_, preds = torch.max(outputs, dim=1)
# print images and ground truth
imshow(torchvision.utils.make_grid(inputs))
print('GroundTruth:', labels)
print ('Prediction:', preds)
可见,这部分的准确率为100%
tips:
使用lightning模式,用jupyter notebook显示进度条不是很良好,应该要进行相应的设置,我这里不太想去研究了。在终端运行,显示效果更好。
此部分代码很简单,赶快来试一试吧
文章持续更新,可以关注微信公众号【医学图像人工智能实战营】,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~
我是Tina, 我们下篇博客见~
最后,求点赞,评论,收藏。或者一键三连
以上是关于pytorch lightning 手写数字分类实例 的主要内容,如果未能解决你的问题,请参考以下文章
pytorch深度学习实践_p9_多分类问题_pytorch手写实现数字辨识
跟着B站学习pytorch-p13 mnist手写数字图片分类问题
图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)
[Pytorch系列-29]:神经网络基础 - 全连接浅层神经网络实现10分类手写数字识别