PyTorch实现简单的变分自动编码器VAE

Posted picassooo

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch实现简单的变分自动编码器VAE相关的知识,希望对你有一定的参考价值。

      在上一篇博客中我们介绍并实现了自动编码器,本文将用PyTorch实现变分自动编码器(Variational AutoEncoder, VAE)。自动变分编码器原理与一般的自动编码器的区别在于需要在编码过程增加一点限制,迫使它生成的隐含向量能够粗略的遵循标准正态分布。这样一来,当需要生成一张新图片时,只需要给解码器一个标准正态分布的隐含随机向量就可以了。

      在实际操作中,实际上不是生成一个隐含向量,而是生成两个向量:一个表示均值,一个表示标准差,然后通过这两个统计量合成隐含向量,用一个标准正态分布先乘标准差再加上均值就行了。具体关于变分自动编码器的内容,可参考廖星宇的《深度学习之PyTorch》的第六章,下面的代码也是来自这个资料,但本文对原代码做了一点改动。

import os
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

# Hyper parameters
EPOCH = 1
LR = 1e-3
BATCHSIZE = 128

im_tfs = tfs.Compose([
    tfs.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                       # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    tfs.Normalize([0.5], [0.5])   # 把[0.0, 1.0]的数据扩大范围到[-1., 1]
])

train_set = MNIST(
    root=‘/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/‘,   # mnist has been downloaded before, use it directly
    train=True,
    transform=im_tfs,
)
train_loader = DataLoader(train_set, batch_size=BATCHSIZE, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)   # mean
        self.fc22 = nn.Linear(400, 20)   # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()                     # 矩阵点对点相乘之后再把这些元素作为e的指数
        eps = torch.FloatTensor(std.size()).normal_()    # 生成随机数组
        if torch.cuda.is_available():
            eps = eps.cuda()
        return eps.mul(std).add_(mu)    # 用一个标准正态分布乘标准差,再加上均值,使隐含向量变为正太分布

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)          # 编码
        z = self.reparametrize(mu, logvar)   # 重新参数化成正态分布
        return self.decode(z), mu, logvar    # 解码,同时输出均值方差


net = VAE()  # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()

reconstruction_function = nn.MSELoss(size_average=False)


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD


optimizer = torch.optim.Adam(net.parameters(), lr=LR)


def to_img(x):   # x shape (bachsize, 28*28), x中每个像素点的大小范围[-1., 1.]
    ‘‘‘
    定义一个函数将最后的结果转换回图片
    ‘‘‘
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x


for epoch in range(EPOCH):
    for iteration, (im, y) in enumerate(train_loader):
        im = im.view(im.shape[0], -1)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0]   # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            print(‘epoch: {:2d} | iteration: {:4d} | Loss: {:.4f}‘.format(epoch, iteration, loss.data.numpy()))
            save = to_img(recon_im.cpu().data)
            if not os.path.exists(‘./vae_img‘):
                os.mkdir(‘./vae_img‘)
            save_image(save, ‘./vae_img/image_{}_{}.png‘.format(epoch, iteration))


# test
code = torch.randn(1, 20)   # 随机给一个符合正态分布的张量
out = net.decode(code)
img = to_img(out)
save_image(img, ‘./vae_img/test_img.png‘)

以上是关于PyTorch实现简单的变分自动编码器VAE的主要内容,如果未能解决你的问题,请参考以下文章

单指标时间序列异常检测——基于重构概率的变分自编码(VAE)代码实现(详细解释)

单指标时间序列异常检测——基于重构概率的变分自编码(VAE)代码实现(详细解释)

单指标时间序列异常检测——基于重构概率的变分自编码(VAE)代码实现(详细解释)

pytorch 笔记:VAE 变分自编码器

pytorch 笔记:VAE 变分自编码器

python 一个简单的变分自动编码器实现