基于Pytorch的神经网络之autoencoder

Posted ZDDWLIG

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于Pytorch的神经网络之autoencoder相关的知识,希望对你有一定的参考价值。

目录

1. 引言

2.结构

3.搭建网络

4.代码


1. 引言

今天我们来学习一种在压缩数据方面有较好效果的神经网络:自编码(autoencoder)。

2.结构

自编码的主要思想就是将数据先不断encode进行降维提取其中的关键信息,再decode解码成新的信息,我们的目标是要使我们生成的信息和原信息尽可能相似,有点像我们做题一样,先刷大量的题,提取其中的关键解题思路,遇到同类型的题目就会写了,基本结构如下

 这个网络先将信息不断压缩,再解压,对比原始数据和新数据的差别,再反向传播修正参数,最后输出的新数据就会越来越接近原始数据,最后最中间的一层就是这组信息的关键特征。

3.搭建网络

我们以输出手写数字为例,我们还是只看网络的搭建过程:

#搭建网络
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),      
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

 我们需要构建编码和解码两个网络,编码部分就是将数据一层一层压缩,解码部分就是对应着一层层解压,中间加入非线性函数,最后再加一个激活函数,但是要确保数据的范围与原来一样,因为我们的目标是生成与原数据一样的数据,最后前向传播先编码再解码即可。

 下面是效果

开始:

 

中间:

 

最后:

 可以看出输出图片与原始图片越来越接近了。

最后的损失:

loss:0.33

4.代码

#调库
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
import numpy as np

#超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005         # learning rate
DOWNLOAD_MNIST = False
N_TEST_IMG = 5

#下载数据
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                    
    transform=torchvision.transforms.ToTensor(),                                 
    download=DOWNLOAD_MNIST,                       
)

#小批数据
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

#搭建网络
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),      
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   

view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())

#训练并绘图
for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)  
        b_y = x.view(-1, 28*28)   

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)   
        optimizer.zero_grad()              
        loss.backward()                   
        optimizer.step()                   

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())

            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(()); a[1][i].set_yticks(())
            plt.draw(); plt.pause(0.05)

plt.ioff()
plt.show()

以上是关于基于Pytorch的神经网络之autoencoder的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch Note43 自动编码器(Autoencoder)

PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码

PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码

PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码

Pytorch中的自编码(autoencoder)

PyTorch - MAE(Masked Autoencoders)推理脚本