Pytorch Note38 RNN 做图像分类

Posted Real&Love

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch Note38 RNN 做图像分类相关的知识,希望对你有一定的参考价值。

Pytorch Note38 RNN 做图像分类


全部笔记的汇总贴: Pytorch Note 快乐星球

图片分类

RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是作为举例。

首先需要将图片数据转化为一个序列数据,MINST手写数字图片的大小是28x28,那么可以将每张图片看作是长为28的序列,序列中的每个元素的特征维度是28,这样就将图片变成了一个序列。同时考虑循环神经网络的记忆性,所以图片从左往右输入网络的时候,网络可以记忆住前面观察东西,然后与后面部分结合得到最后预测数字的输出结果,理论上是行得通的。

对于一张手写字体的图片,其大小是 28 * 28,我们可以将其看做是一个长为 28 的序列,每个序列的特征都是 28,也就是

这样我们解决了输入序列的问题,对于输出序列怎么办呢?其实非常简单,虽然我们的输出是一个序列,但是我们只需要保留其中一个作为输出结果就可以了,这样的话肯定保留最后一个结果是最好的,因为最后一个结果有前面所有序列的信息,就像下面这样

在数据集中可以先标准化,然后对其可视化

# 定义模型
class rnn_classify(nn.Module):
    def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
        super(rnn_classify, self).__init__()
        self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers, batch_first = True) # 使用两层 lstm 将batch放在前面
        self.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果
        
    def forward(self, x):
        '''
        x 大小为 (batch, 1, 28, 28),所以我们需要将其转换成 RNN 的输入形式,即 (28, batch, 28)
        '''
        x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1,变成 (batch, 28, 28)
        out, _ = self.rnn(x) # 使用默认的隐藏状态,得到的 out 是 (batch, 28, hidden_feature)
        out = out[:, -1, :] # 取序列中的最后一个,大小是 (batch, hidden_feature)
        out = self.classifier(out) # 得到分类结果
        return out
# 开始训练
from utils import train
Acc,Loss,Lr = train(net, train_loader, test_loader, 10, optimzier, criterion, scheduler, verbose=True)
Epoch [  1/ 10]  Train Loss:1.750572  Train Acc:37.02% Test Loss:0.697916  Test Acc:77.95%  Learning Rate:0.100000	Time 00:14
Epoch [  2/ 10]  Train Loss:0.361676  Train Acc:88.89% Test Loss:0.193300  Test Acc:94.48%  Learning Rate:0.100000	Time 00:13
Epoch [  3/ 10]  Train Loss:0.187760  Train Acc:94.27% Test Loss:0.176153  Test Acc:94.53%  Learning Rate:0.100000	Time 00:14
Epoch [  4/ 10]  Train Loss:0.138180  Train Acc:95.82% Test Loss:0.150296  Test Acc:95.45%  Learning Rate:0.100000	Time 00:13
Epoch [  5/ 10]  Train Loss:0.110974  Train Acc:96.59% Test Loss:0.119391  Test Acc:96.45%  Learning Rate:0.100000	Time 00:14
Epoch [  6/ 10]  Train Loss:0.093230  Train Acc:97.19% Test Loss:0.103071  Test Acc:96.94%  Learning Rate:0.100000	Time 00:15
Epoch [  7/ 10]  Train Loss:0.081078  Train Acc:97.50% Test Loss:0.093530  Test Acc:97.21%  Learning Rate:0.100000	Time 00:15
Epoch [  8/ 10]  Train Loss:0.070568  Train Acc:97.83% Test Loss:0.075076  Test Acc:97.76%  Learning Rate:0.100000	Time 00:14
Epoch [  9/ 10]  Train Loss:0.062708  Train Acc:98.06% Test Loss:0.069279  Test Acc:98.03%  Learning Rate:0.100000	Time 00:14
Epoch [ 10/ 10]  Train Loss:0.056354  Train Acc:98.29% Test Loss:0.065537  Test Acc:98.15%  Learning Rate:0.100000	Time 00:14

可以看到,训练 10 次在简单的 mnist 数据集上也取得的了 98% 的准确率,应该是比较高了。虽然在一个简单的图片数据集上能够达到相对满意的效果,但是循环神经网络还是不适合处理图片类型的数据:

  • 第一个原因是图片并没有很强的序列关系,图片中的信息可以从左往右看,也可以从右往左看,甚至可以跳着随机看,不管是什么样的方式都能够完整地理解图片信息;

  • 第二个原因是循环神经网络传递的时候,必须前面一个数据计算结束才能进行后面一个数据的计算,这对于大图片而言是很慢的,但是卷积神经网络并不需要这样,因为它能够并行,在每一层卷积中,并不需要等待第一个卷积做完才能做第二个卷积,整体是可以同时进行的。

以上是关于Pytorch Note38 RNN 做图像分类的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch实现RNN网络对MNIST字体分类

Pytorch Note40 词嵌入(word embedding)

Pytorch Note37 PyTorch 中的循环神经网络模块

承接TensorFlow深度学习代做pytorch图像处理

干货对抗自编码器PyTorch手把手实战系列——PyTorch实现对抗自编码器

PyTorch-14 使用字符级RNN分类名字(姓名)