生成对抗网络中的鉴别器损失没有改变

Posted

技术标签:

【中文标题】生成对抗网络中的鉴别器损失没有改变【英文标题】:Discriminator Loss Not Changing in Generative Adversarial Network 【发布时间】:2021-02-19 10:45:42 【问题描述】:

我正在尝试使用 pix2pix GAN 生成器和 Unet 作为鉴别器来训练 GAN。但经过一些时期后,我的鉴别器损失停止变化并停留在 5.546 附近。 GAN 训练是好兆头还是坏兆头。

这是我的损失计算:

def discLoss(rValid, rLabel, fValid, fLabel):
    # validity loss
    bce =     tf.keras.losses.BinaryCrossentropy(from_logits=True,label_smoothing=0.1)
    # classifier loss
    scce =     tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Loss for real
    real_dloss = (bce(tf.ones_like(rValid), rValid) + scce(label, rLabel))#/2
    # Loss for fake
    fake_dloss = (bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel))#/2
    # Total discriminator loss
    d_loss = (real_dloss + fake_dloss)# / 2
    return d_loss

def generator_loss(disc_generated_output, gen_output, target):
  loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  LAMBDA = 100
  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss

这是我的火车步骤:

def train_step(img1, img2, label, generator,discriminator,generator_optimizer,discriminator_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fImg = generator([img1, label], training=True)
    rValid, rLabel = discriminator(img2, training=True)
    fValid, fLabel = discriminator(fImg, training=True)

    disc_loss = discLoss(rValid, rLabel, fValid, fLabel)
    gen_loss = generator_loss(fValid, fImg, img2)
    # genLoss(label, rValid, rLabel, fValid, fLabel)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return tf.math.reduce_mean(gen_loss).numpy(), disc_loss.numpy()

【问题讨论】:

请直接复制代码,不要链接到图片。这是你的全部代码吗? 【参考方案1】:

这个损失太高了。您需要注意 G 和 D 都以均匀的速度学习。访问此问题和相关链接:How to balance the generator and the discriminator performances in a GAN?

【讨论】:

以上是关于生成对抗网络中的鉴别器损失没有改变的主要内容,如果未能解决你的问题,请参考以下文章

LSGAN:最小二乘生成对抗网络

生成对抗网络需要类别标签吗?

GAN-生成对抗网络-生成手写数字(基于pytorch)

深度学习100例-生成对抗网络(DCGAN)手写数字生成 | 第19天

深度学习100例-生成对抗网络(DCGAN)手写数字生成 | 第19天

如何解释分类生成对抗网络中的损失函数?