GAN 中的损失函数

Posted

技术标签:

【中文标题】GAN 中的损失函数【英文标题】:Loss functions in GANs 【发布时间】:2018-10-03 22:28:57 【问题描述】:

我正在尝试构建一个简单的 mnist GAN,不用多说,它没有用。我已经搜索了很多并修复了我的大部分代码。虽然我不能真正理解损失函数是如何工作的。

这就是我所做的:

loss_d = -tf.reduce_mean(tf.log(discriminator(real_data))) # maximise
loss_g = -tf.reduce_mean(tf.log(discriminator(generator(noise_input), trainable = False))) # maxmize cuz d(g) instead of 1 - d(g)
loss = loss_d + loss_g

train_d = tf.train.AdamOptimizer(learning_rate).minimize(loss_d)
train_g = tf.train.AdamOptimizer(learning_rate).minimize(loss_g)

我得到 -0.0 作为我的损失值。你能解释一下如何处理 吗?

【问题讨论】:

如果不看得太深入,你可能没有设置正确的学习率,导致你的权重爆炸,给你 NaN。 哎呀!我的意思是 -0.0 而不是 NaN。对不起。我会编辑它。 我自己还没有编码,所以我不会回答,但我相信你需要在minimize函数上设置var_list属性。您正在做的是在所有变量上定义两个优化器。如果您的鉴别器和生成器在同一个图表中,那么您正在执行两个相反的更新。鉴别器优化器应该只更新鉴别器的权重,同样地更新生成器的优化器。您应该使用tf.variable_scope 来帮助将变量分为两组。 tf.variable_scope 的替代方法是通过从 tf.keras.Model 继承来组织事物,就像在 eager GAN example 中一样。 (该示例正在急切地执行,但是通过对训练循环进行一些调整,将其切换到图形构建应该相对容易)。然后你会得到每个组件的 .variables 属性。 我主要依赖这个:blog.aylien.com/… 它使用 var_list。 【参考方案1】:

您似乎试图将生成器和鉴别器的损失相加,这是完全错误的! 由于鉴别器使用真实数据和生成数据进行训练,因此您必须创建两种不同的损失,一种用于真实数据,另一种用于传入鉴别器网络的噪声数据(生成)。

尝试如下更改您的代码:

1)

loss_d_real = -tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(real_data),labels= tf.ones_like(discriminator(real_data))))

2)

loss_d_fake=-tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(noise_input),labels= tf.zeros_like(discriminator(real_data))))

那么鉴别器损失将等于= loss_d_real+loss_d_fake。 现在为您的生成器创建损失:

3)

loss_g= tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(genereted_samples), labels=tf.ones_like(genereted_samples)))

【讨论】:

嗨,您能否改进代码格式以澄清您的答案 - 请参阅***.com/help/formatting【参考方案2】:

Maryam 似乎已经确定了您的虚假损失值的原因(即对生成器和鉴别器损失求和)。只是想补充一点,您可能应该为鉴别器选择随机梯度下降优化器来代替 Adam - 这样做可以在玩极小极大游戏时为网络的收敛提供更强的理论保证(来源:https://github.com/soumith/ganhacks)。

【讨论】:

以上是关于GAN 中的损失函数的主要内容,如果未能解决你的问题,请参考以下文章

为啥我们在编译组合 GAN (SRGAN) 网络时使用两个损失

gan算法不包括以下哪个模型

卷积神经网络之GAN(附完整代码)

将 torch.backward() 用于 GAN 生成器时,为啥 Pytorch 中的判别器损失没有变化?

GAN - 生成器损失减少,但鉴别器假损失在初始下降后增加,为啥?

Pytorch深度学习50篇·······第六篇:常见损失函数篇-----BCELoss及其变种