在生成对抗网络中如何使用鉴别器的输出训练生成器

Posted

技术标签:

【中文标题】在生成对抗网络中如何使用鉴别器的输出训练生成器【英文标题】:how the generator is trained with the output of discriminator in Generative adversarial Networks 【发布时间】:2017-11-27 11:24:24 【问题描述】:

最近我了解了Generative Adversarial Networks。

为了训练生成器,我对它的学习方式感到困惑。 Here 是 GAN 的实现:

`# train generator
            z = Variable(xp.random.uniform(-1, 1, (batchsize, nz), dtype=np.float32))
            x = gen(z)
            yl = dis(x)
            L_gen = F.softmax_cross_entropy(yl, Variable(xp.zeros(batchsize, dtype=np.int32)))
            L_dis = F.softmax_cross_entropy(yl, Variable(xp.ones(batchsize, dtype=np.int32)))

        # train discriminator

        x2 = Variable(cuda.to_gpu(x2))
        yl2 = dis(x2)
        L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))

        #print "forward done"

        o_gen.zero_grads()
        L_gen.backward()
        o_gen.update()

        o_dis.zero_grads()
        L_dis.backward()
        o_dis.update()`

因此,它会计算论文中提到的生成器的损失。 但是,它会根据鉴别器输出调用生成器后向函数。鉴别器输出只是一个数字(不是数组)。

但我们知道,一般来说,为了训练网络,我们会在最后一层计算损失函数(最后一层输出和实际输出之间的损失),然后计算梯度。例如,如果输出是 64*64,那么我们将其与 64*64 的图像进行比较,然后计算损失并进行反向传播。

但是,在我在生成对抗网络中看到的代码中,我看到它们根据鉴别器输出(只是一个数字)计算生成器的损失,然后调用生成器的反向传播。生成器的最后一层是例如 64*64 像素,但鉴别器损失是 1*1(这与通常的网络不同)所以我不明白它是如何导致生成器被学习和训练的?

我想如果我们附加两个网络(附加生成器和鉴别器)然后调用反向传播但只更新生成器参数,这是有意义的并且应该可以工作。但我在代码中看到的完全不同。

所以我在问怎么可能?

谢谢

【问题讨论】:

你的问题不是很清楚,但是看看这是否有帮助。鉴别器是一个普通的分类器,它以图像为输入,对它的真假进行分类。真实数据来自训练集,假数据来自生成器。所以判别器是根据这两个输入来学习的。对于生成器的情况,它必须欺骗鉴别器,因此将生成器的输出馈送到鉴别器,并通过将鉴别器的输出设置为非假来学习生成器。这里只学习生成器。 谢谢。对不起我的坏问题。我的问题只是关于代码。我清楚地理解算法。我的问题是:为了训练生成器,我们必须将损失从鉴别器反向传播到生成器,但不更新鉴别器参数。但是,在代码中,我只看到他们使用鉴别器输出(损失)并且没有通过鉴别器反向传播,他们将其发送到生成器。我的错误是什么? 我对此并不完全确定,但我明白你的意思。如果反向传播确实通过了鉴别器(因为我们需要扩大规模),这对我来说是有意义的,但是,权重更新仅应用于网络的生成器部分 【参考方案1】:

你说'但是,它会根据鉴别器输出调用生成器后向函数。鉴别器输出只是一个数字(不是数组)',而损失始终是一个标量值。当我们计算两个图像的均方误差时,它也是一个标量值。

L_adversarial = E[log(D(x))]+E[log(1−D(G(z))]

x 来自真实数据分布

z 是由生成器转换的潜在数据分布

回到您的实际问题,判别器网络在最后一层有一个 sigmoid 激活函数,这意味着它的输出范围为 [0,1]。鉴别器试图通过最大化损失函数中添加的两个项来最大化这种损失。第一项的最大值为 0,当 D(x) 为 1 时发生,第二项的最大值也为 0,当 1-D(G(z)) 为 1 时发生,这意味着 D(G(z)) 为 0 . 所以鉴别器试图做一个二进制分类,我最大化这个损失函数,当它被输入 x(真实数据)时它试图输出 1,当它被输入 G(z)(生成的假数据)时输出 0。 但是生成器试图最小化这种损失,换句话说,它试图通过生成与真实样本相似的假样本来欺骗判别器。随着时间的推移,生成器和鉴别器都变得越来越好。这就是 GAN 背后的直觉。

代码在pytorch中

bce_loss = nn.BCELoss() #bce_loss = -ylog(y_hat)-(1-y)log(1-y_hat)[similar to L_adversarial]

Discriminator = ..... #some network   
Generator = ..... #some network

optimizer_generator = ....... #some optimizer for generator network    
optimizer_discriminator = ....... #some optimizer for discriminator network       

z = ...... #some latent data distribution that is transformed by the generator
real = ..... #real data distribution

#####################
#Update Discriminator
#####################
fake = Generator(z)
fake_prediction = Discriminator(fake)
real_prediction = Discriminator(real)
discriminator_loss = bce_loss(fake_prediction,torch.zeros(batch_size))+bce_loss(real_prediction,torch.ones(batch_size))
discriminator_loss.backward()
optimizer_discriminator.step()

#################
#Update Generator
#################
fake = Generator(z)
fake_prediction = Discriminator(fake)
generator_loss = bce_loss(fake_prediction,torch.ones(batch_size))
generator_loss.backward()
optimizer_generator.step()

【讨论】:

以上是关于在生成对抗网络中如何使用鉴别器的输出训练生成器的主要内容,如果未能解决你的问题,请参考以下文章

生成对抗网络需要类别标签吗?

生成对抗网络中的鉴别器损失没有改变

生成对抗网络

深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

利用tensorflow训练简单的生成对抗网络GAN