对抗性自动编码器系列--对抗自动编码器AAE的原理及实现-从任意随机数重建手写数字

Posted Tina姐

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了对抗性自动编码器系列--对抗自动编码器AAE的原理及实现-从任意随机数重建手写数字相关的知识,希望对你有一定的参考价值。

前言

先来看看实验:

我们使用 MNIST 手写数字,测试通过自动编码器和对抗性自动编码器学习重建恢复效果。

  • 原始图像:
  • 自动编码器重建效果
  • 对抗性自动编码器重建效果
  • 有监督对抗性自动编码器重建效果

虽然这里看到,自动编码器和对抗性自动编码器重建出来的能力差不多,有监督对抗性自动编码器基本上重建出来的图像和输入基本对的上。他们的差别有何不同呢,通过之后几章的学习,大家会有体会。

我们学习自动编码器有什么用?
重建图像本身自然是没有任何意义的,但是能把图像重建出来,说明模型学到了输入图像集的分布和特征。

  • 提取图像特征,特征我们可以拿来做影像组学。
  • 异常检测,图像的分布可以拿来做异常检测。
  • 图像去噪,其中可以使用有噪声的图像生成清晰的无噪声图像。
  • 语义散列可以使用降维来加快信息检索速度。
  • 最近,以对抗方式训练的自动编码器可以用作生成模型(我们稍后会深入探讨)。

具体地, 我们将从以下几部分来介绍:

  1. 自动编码器重建 MNIST 手写数字
  2. 对抗性自动编码器重建 MNIST 手写数字
  3. 半监督自动编码器重建 MNIST 手写数字
  4. 使用自动编码器对 MNIST 进行分类

这是本系列第二个类容:对抗性自动编码器

对抗性自动编码器 AAE

在上一节中,上一节的链接

我们通过将值 (0, 0) 传递给我们训练有素的 Decoder, 恢复出来的图像看起来很模糊,并没有代表一个清晰的数字,这让我们得出这样的结论:Encoder h(潜在空间)的输出在特定空间中没有均匀分布。

所以,我们在这部分的主要目标是强制 Encoder 输出匹配给定的先验分布,这个所需的分布可以是正态(高斯)分布、均匀分布、伽马分布……。这应该会导致 latent space(Encoder输出)均匀分布在给定的先验分布上,这将允许我们的 Decoder 学习从先验到数据分布(在我们的例子中是 MNIST 图像的分布)的映射。

如果你不是很理解这段话,可以看看下面的解释:

假设您在上大学并选择了机器学习(我想不出另一门课程)作为您的课程之一。现在,如果课程讲师不提供教学大纲指南或参考书,那么您将在期末学习什么? (假设您的课程没有帮助)。

你可能会被问到机器学习的任何子领域的问题,甚至问你化妆的东西你知道吗??

如果我们不限制 Encoder 输出遵循某种分布,就会发生这种情况,Decoder 无法学习从任何数字到图像的映射。

但是,如果您获得了适当的教学大纲指南,那么您只需在考试前阅读材料,您就会知道会发生什么。

类似地,如果我们强制 Encoder 输出遵循已知分布(如高斯分布),那么它可以学习 z 以覆盖整个分布并学习没有任何间隙的映射。

我们现在知道 AE 有两个部分,每个部分执行完全相反的任务。

用于从输入中获取 z 的 Encoder,其约束条件是 z 的维度应小于输入维度,其次,接收此 z 并尝试重构原始图像的 Decoder。

让我们看看当我们之前实现我们的 AE Encoder 输出是如何分布的(第 1 部分):

从分布图(右),我们可以清楚地看到我们的编码器的输出分布到处都是。最初,分布似乎以 0 为中心,大多数值为负。在训练的后期阶段,与正样本相比,负样本分布更远离 0(此外,如果我们再次运行实验,我们甚至可能不会得到相同的分布)。这会导致编码器分布出现大量间隙,如果我们想将Decoder用作生成模型,这不是一件好事。

但是,为什么在我们的编码器分布中存在这些间隙是件坏事 如果我们将落入此间隙的输入提供给经过训练的Decoder,那么它会给出看起来很奇怪的图像,这些图像在其输出处不代表数字。

我们现在看看可以解决上述一些问题的对抗性自动编码器 AAE。

Adversarial autoencoder(AAE) 与 Autoencoder(AE) 非常相似,但 encoder 以对抗方式进行训练以强制其输出所需的分布。

理解 AAE 需要具备生成对抗性网络 (GAN) 的知识。可以查看GAN 的基础知识

如果您已经了解 GAN,这里有一个快速回顾(如果您还记得接下来的两点,请随时跳过本节):

discriminator 判别器

generator 生成器

  • GAN 有两个神经网络,一个生成器和一个判别器。
  • 生成器,很好地生成假图像。我们训练鉴别器将我们的数据集中的真实图像与生成器生成的假图像区分开来。
  • 生成器最初会产生一些随机噪声(因为它的权重是随机的)。在训练我们的鉴别器来区分这些随机噪声和真实图像后,我们将我们的生成器连接到我们的鉴别器,并且只通过生成器反向传播。
  • 我们将再次训练我们的鉴别器来区分来自我们的生成器的新假图像和来自我们数据库的真实图像。然后训练生成器以生成更好看的假图像。
  • 我们将继续这个过程,直到生成器非常擅长生成假图像,以致判别器不再能够区分真图像和假图像。
  • 最后,我们将得到一个生成器,该生成器可以在给定一组随机数字作为输入的情况下生成真实的假图像。

对抗性自动编码器 AAE 的流程图:

  • x → Input image
  • q(z/x) → Encoder output given input x
  • z → Latent code (fake input), z is drawn from q(z/x)
  • z’ → Real input with the required distribution
  • p(x/z) →Decoder output given z
  • D() → Discriminator
  • x_ →Reconstructed image

同样,我们的主要任务是强制 encoder 输出具有给定先验分布的值(这可以是正态分布、伽马分布)。我们将使用 encoder 作为我们的生成器、Discriminator 鉴别器来判断样本是来自先验分布 (p(z)) 还是来自 encoder (z) 。

为了了解如何使用这种架构对 encoder 输出施加先验分布,让我们看看我们如何训练 AAE。

AAE 的训练

训练 AAE 有两个阶段:

  • 重建阶段
    我们将训练 encoder 和 Decoder 以最小化重建损失(输入和 Decoder 输出图像之间的均方误差,查看第 1 部分了解更多详细信息)。忘记鉴别器甚至存在于这个阶段(我已经把这个阶段不需要的部分灰化了)。

    像第一部分一样,我们将输入传递给 Encoder,将为我们提供潜在代码,稍后,我们将这个 z 传递给 Decoder 以获取输入图像。我们将通过 Encoder和 Decoder 的权重进行反向传播,以便减少重建损失。

  • 正则化阶段
    在这个阶段,我们必须训练 Discriminator(鉴别器)和生成器(指的是 Encoder)。只是忘记了 Decoder 的存在。

    首先,我们训练鉴别器对 Encoder 输出(z)和一些随机输入(z’,某种分布)进行分类。例如,随机输入可以呈正态分布,均值为 0,标准差为 5。

因此,如果我们传入具有所需分布(从该分布中采样的z’定义为真实值)的随机输入,鉴别器应该给我们一个输出 1,而当我们传入 Encoder 输出时,它应该给我们一个输出 0(假值)。

因此,Encoder 输出和鉴别器的输入应该具有相同的大小。

现在,理论部分已经结束,让我们来看看如何实现这一点。

Encoder

Decoder

Discriminator

这部分的 Encoder 和 Decoder 和第一部分是一摸一样的,如果你已经实现了第一部分的代码,直接 copy 过来就好了。Discriminator 也是一个很简单的分类网络,输出一个节点(表示概率)。

我们现在知道训练 AAE 有两个部分,首先是重建阶段(我们将训练我们的 AE 来重建输入)和正则化阶段(首先训练 Discriminator,然后是 Encoder)。

损失函数

  • AE loss: 同第一部分一样,使用 MSE loss.
  • 对抗loss: 使用二进制交叉熵
    • 判别器:训练判别器的时候,Encoder 的输出作为输入时,对应的标签应该是 0, 从正态分布中采样的输入,对应的标签为1.
    • 生成器:也就是 Encoder。生成器的loss跟判别器相反,Encoder 的输出作为输入时,对应的标签应该是 1。不计算 从正态分布中采样的输入的损失。

看代码:

generator_loss = bce_loss(d_fake, target=torch.ones_like(d_fake))
            # encoder: d_fake值越接近1越好
discriminator_loss = bce_loss(d_fake, target=torch.zeros_like(d_fake)) +\\
                                 bce_loss(d_real, target=torch.ones_like(d_real))
            # discriminator: d_fake值越接近0越好, d_real值越接近1越好

这部分较第一部分多了一个 discriminator 和 GAN 的训练。如果对 GAN 不是很熟悉的,可能在训练上会有难度,但这是一个非常基础的 GAN 训练,我也会提供 Tensorflow 和 Pytorch 代码,我的 torch 代码也写得非常简单明了,认真读代码还是可以很快入门的。

AAE 结果



原始图像

重建图像

最后,测试随机输入的重建效果,我们可以将属于所需分布的随机输入传递给我们训练有素的 Decoder 并将其用作生成器(我知道,我一直将 Encoder 称为生成器)它正在为鉴别器生成假输入。

将 z 值从 (-10, -10) 到 (10, 10) 传递给 Decoder,并将其输出存储为数字分布方式:

这部分实验代码

虽然,我们的网络已经可以从满足某种分布的随机数生成数字了,但是从图上来看,根本不知道他会生成0-9的哪个数字。如果我们想采样一个随机数,想要它生成特定数字,怎么做?

如果感兴趣,欢迎进入下一章节,有监督的对抗自动编码器。

文章持续更新,可以关注微信公众号【医学图像人工智能实战营】获取最新动态,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~

我是Tina, 我们下篇博客见~

白天工作晚上写文,呕心沥血

觉得写的不错的话最后,求点赞,评论,收藏。或者一键三连

以上是关于对抗性自动编码器系列--对抗自动编码器AAE的原理及实现-从任意随机数重建手写数字的主要内容,如果未能解决你的问题,请参考以下文章

对抗性自动编码器系列--自动编码器AutoEncoder的原理及实现-手写数字的重建

AAE对抗自编码器/GAN与VAE的区别

对抗性自动编码器无法正常工作且无法正确学习

干货对抗自编码器PyTorch手把手实战系列——PyTorch实现对抗自编码器

GAN生成式对抗网络

GAN生成式对抗网络