无法理解 tensorflow 文档中使用的 GAN 模型的损失函数

Posted

技术标签:

【中文标题】无法理解 tensorflow 文档中使用的 GAN 模型的损失函数【英文标题】:Can't understand the loss functions for the GAN model used in the tensorflow documentation 【发布时间】:2020-08-12 18:31:29 【问题描述】:

我无法理解 TensorFlow 文档中 GAN 模型中的损失函数。为什么将tf.ones_like() 用于real_losstf.zeros_like() 用于假输出??

def discriminator_loss(real_output,fake_output):
  real_loss = cross_entropy(tf.ones_like(real_output),real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output)
  total_loss = real_loss + fake_loss
  return total_loss

【问题讨论】:

【参考方案1】:

我们有以下损失函数,我们需要以最小最大方式(或最小最大,如果你想这样称呼它)最小化。

    generator_loss = -log(generated_labels) discriminator_loss = -log(real_labels) - log(1 - generated_labels)

其中real_output = real_labels 和fake_output = generated_labels。

现在,考虑到这一点,让我们看看 TensorFlow 文档中的代码 sn-p 代表什么:

real_loss = cross_entropy(tf.ones_like(real_output), real_output) 评估为 real_loss = -1 * log(real_output) - (1 - 1) * log(1 - real_output) = -log(real_output) fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output) 评估为 fake_loss = -0 * log(fake_output) - (1 - 0) * log(1 - fake_output) = -log(1 - fake_output) total_loss = real_loss + fake_loss 评估为 total_loss = -log(real_output) - log(1 - fake_output)

显然,我们得到了我们想要最小化的 mini-max 游戏中判别器的损失函数。

【讨论】:

以上是关于无法理解 tensorflow 文档中使用的 GAN 模型的损失函数的主要内容,如果未能解决你的问题,请参考以下文章

无法匹配 GA 高级细分和 BigQuery 结果

如果我想使用无法通过 TensorFlow 加载到内存中的大型数据集,我该怎么办?

TensorFLow1.3文档中文翻译之1.0.0安装

在张量流中读取数据

tf.nn.conv2d 在 tensorflow 中做了啥?

使用 OCR 引擎 tesseract 无法理解提取文档中的坐标