GAN (生成对抗网络) 手写数字图片生成

Posted Real&Love

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN (生成对抗网络) 手写数字图片生成相关的知识,希望对你有一定的参考价值。

GAN (生成对抗网络) 手写数字图片生成


这种训练方式定义了一种全新的网络结构,就是生成对抗网络,也就是 GANs。这一部分,我们会形象地介绍生成对抗网络,以及用代码进行实现,而在书中会更加详细地介绍 GANs 的数学推导。

根据这个名字就可以知道这个网络是由两部分组成的,第一部分是生成,第二部分是对抗。简单来说,就是有一个生成网络和一个判别网络,通过训练让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

可以用这个图来简单的看一看这两个过程

Discriminator Network

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的 label 没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label 是 1 表示真实的;而生成的假的图片的 label 是 0 表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如 logistic 回归,深层网络,卷积神经网络,循环神经网络都可以。

Generator Network

接着我们看看生成网络如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,如上图所示的 D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是 xw+b 将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、转置卷积、池化、激活函数等进行处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片。

这个时候我们如何去训练这个生成器呢?这就需要通过对抗学习,增大判别器判别这个结果为真的概率,通过这个步骤不断调整生成器的参数,希望生成的图片越来越像真的,而在这一步中我们不会更新判别器的参数,因为如果判别器不断被优化,可能生成器无论生成什么样的图片都无法骗过判别器。

生成器的效果可以看看下面的图示

关于生成对抗网络,出现了很多变形,比如 WGAN,LS-GAN 等等,这一节我们只使用 mnist 举一些简单的例子来说明,更复杂的网络结构可以再 github 上找到相应的实现

简单版本的生成对抗网络

通过前面我们知道生成对抗网络有两个部分构成,一个是生成网络,一个是对抗网络,我们首先写一个简单版本的网络结构,生成网络和对抗网络都是简单的多层神经网络

所以,如果我们需要完成一个生成对抗网络,我们需要一个生成器判别器

判别器 Discriminator

判别网络的结构非常简单,就是一个二分类器,结构如下:

  • 全连接(784 -> 1024)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(1024 -> 512)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(512 -> 256)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(256 -> 1)
  • Sigmoid

其中 leakyrelu 是指 f(x) = max( α \\alpha α x, x)

我们判别网络实际上就是一个二分类器,我们需要判断我们的图片是真还是假

class discriminator(nn.Module):
    def __init__(self,input_size):
        super(discriminator,self).__init__()
        
        self.dis = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self,x):
        out = self.dis(x)
        return out

生成器 Generator

接下来我们看看生成网络,生成网络的结构也很简单,就是根据一个随机噪声生成一个和数据维度一样的张量,结构如下:

  • 全连接(噪音维度 -> 256)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(256 -> 512)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(512 -> 1024)
  • leakyrelu, α \\alpha α 是 0.2
  • 全连接(1024 -> 784)
  • tanh 将数据裁剪到 -1 ~ 1 之间
class generator(nn.Module):
    def __init__(self, noise_dim):
        super(generator,self).__init__()
        
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.gen(x)
        return out

超参数设置

对于对抗网络,相当于二分类问题,将真的判别为真的,假的判别为假的,作为辅助,可以参考一下论文中公式

ℓ D = E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \\ell_D = \\mathbb{E}_{x \\sim p_\\text{data}}\\left[\\log D(x)\\right] + \\mathbb{E}_{z \\sim p(z)}\\left[\\log \\left(1-D(G(z))\\right)\\right] D=Expdata[logD(x)]+Ezp(z)[log(1D(G(z)))]
而对于生成网络,需要去骗过对抗网络,也就是将假的也判断为真的,作为辅助,可以参考一下论文中公式

ℓ G = E z ∼ p ( z ) [ log ⁡ D ( G ( z ) ) ] \\ell_G = \\mathbb{E}_{z \\sim p(z)}\\left[\\log D(G(z))\\right] G=Ezp(z)[logD(G(z))]
如果你还记得前面的二分类 loss,那么你就会发现上面这两个公式就是二分类 loss

b c e ( s , y ) = y ∗ log ⁡ ( s ) + ( 1 − y ) ∗ log ⁡ ( 1 − s ) bce(s, y) = y * \\log(s) + (1 - y) * \\log(1 - s) bce(s,y)=ylog(s)+(1y)log(1s)
如果我们把 D(x) 看成真实数据的分类得分,那么 D(G(z)) 就是假数据的分类得分,所以上面判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=3e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=3e-4, betas=(0.5, 0.999))

训练网络生成图片

d_losses = []
g_losses = []
iter_count = 0
for i in range(nepochs):
    for img,_ in train_loader:
        num_img = img.shape[0] # 图片的数量
        real_img = img.view(num_img,-1) 
        real_img = real_img.to(device) # 真实图片
        real_label = Variable(torch.ones(num_img,1)).to(device) # 随机得到单位张量作为真实标签 1
        fake_label = Variable(torch.zeros(num_img,1)).to(device) # 随机得到零张量作为假标签 0
        
        real_out = D(real_img) # 判别真实图片
#         print(real_out.shape)
        
        d_loss_real = criterion(real_out,real_label) # 真实图片的损失
        real_scores = real_out
        
        z = torch.randn(num_img, NOISE).to(device) # 随机生成z NOISE造成的数据
        fake_img = G(z) # 生成假图片
        fake_out = D(fake_img) # 得到D(G(z))
        d_loss_fake = criterion(fake_out,fake_label) # log(1-D(G(z)))
        fake_scores = fake_out
        
        d_loss = d_loss_real + d_loss_fake # 总的损失 x-logD(x) + z-log(1-D(G(z))) 
        d_optimizer.zero_grad() # 梯度归0
        d_loss.backward() # 反向传播
        d_optimizer.step() # 更新生成网络的参数
        
        z = torch.randn(num_img, NOISE).to(device) # 随机生成z NOISE造成的数据
        fake_img = G(z) # 生成图片
        output = D(fake_img) # 经过判别器得到结果
        g_loss = criterion(output, real_label) # 得到假的图片和真实图片的label的loss log(D(G(z)))
        
        g_optimizer.zero_grad() # 归0梯度
        g_loss.backward() # 反向传播
        g_optimizer.step() # 更新生成网络的参数

        if (iter_count % 250 == 0):
#                 display.clear_output(True)
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_loss.data, g_loss.data))
                d_losses.append(d_loss),g_losses.append(g_loss)
                imgs_numpy = deprocess_img(fake_img.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.savefig("images/%d.png" % iter_count) # 每250次保存一次图片
                plt.show()
                print()
        iter_count += 1

Iter: 0, D: 1.364, G:0.6648
在这里插入图片描述

Iter: 250, D: 1.362, G:0.8941
在这里插入图片描述

Iter: 93500, D: 1.331, G:0.8405
在这里插入图片描述
Iter: 93750, D: 1.303, G:0.7253
在这里插入图片描述

我们可以看到,到后面,我们基本可以看到了一个比较好的数字样本图片了,而这些图片都是假的图片,是靠我们的GAN生成出来的,从一开始全是噪声,慢慢的生成这样,还是很不错的,不够迭代了比较长的时间。

我们已经完成了一个简单的生成对抗网络,是不是非常容易呢。但是可以看到效果并不是特别好,迭代的次数有点多,因为我们仅仅使用了简单的多层全连接网络。除了这种最基本的生成对抗网络之外,还有很多生成对抗网络的变式,有结构上的变式,也有 loss 上的变式,不过这些,就留到下次再说吧。

以上是关于GAN (生成对抗网络) 手写数字图片生成的主要内容,如果未能解决你的问题,请参考以下文章

生成对抗网络(GAN)详细介绍及数字手写体生成应用仿真(附代码)

深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

万物皆可 GAN生成对抗网络生成手写数字 Part 1

GAN-生成对抗网络-生成手写数字(基于pytorch)

对抗生成网络GAN系列——GAN原理及手写数字生成小案例