生成对抗网络中的鉴别器损失没有改变
Posted
技术标签:
【中文标题】生成对抗网络中的鉴别器损失没有改变【英文标题】:Discriminator Loss Not Changing in Generative Adversarial Network 【发布时间】:2021-02-19 10:45:42 【问题描述】:我正在尝试使用 pix2pix GAN 生成器和 Unet 作为鉴别器来训练 GAN。但经过一些时期后,我的鉴别器损失停止变化并停留在 5.546 附近。 GAN 训练是好兆头还是坏兆头。
这是我的损失计算:
def discLoss(rValid, rLabel, fValid, fLabel):
# validity loss
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,label_smoothing=0.1)
# classifier loss
scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Loss for real
real_dloss = (bce(tf.ones_like(rValid), rValid) + scce(label, rLabel))#/2
# Loss for fake
fake_dloss = (bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel))#/2
# Total discriminator loss
d_loss = (real_dloss + fake_dloss)# / 2
return d_loss
def generator_loss(disc_generated_output, gen_output, target):
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
LAMBDA = 100
# mean absolute error
l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
total_gen_loss = gan_loss + (LAMBDA * l1_loss)
return total_gen_loss
这是我的火车步骤:
def train_step(img1, img2, label, generator,discriminator,generator_optimizer,discriminator_optimizer):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
fImg = generator([img1, label], training=True)
rValid, rLabel = discriminator(img2, training=True)
fValid, fLabel = discriminator(fImg, training=True)
disc_loss = discLoss(rValid, rLabel, fValid, fLabel)
gen_loss = generator_loss(fValid, fImg, img2)
# genLoss(label, rValid, rLabel, fValid, fLabel)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
return tf.math.reduce_mean(gen_loss).numpy(), disc_loss.numpy()
【问题讨论】:
请直接复制代码,不要链接到图片。这是你的全部代码吗? 【参考方案1】:这个损失太高了。您需要注意 G 和 D 都以均匀的速度学习。访问此问题和相关链接:How to balance the generator and the discriminator performances in a GAN?
【讨论】:
以上是关于生成对抗网络中的鉴别器损失没有改变的主要内容,如果未能解决你的问题,请参考以下文章
深度学习100例-生成对抗网络(DCGAN)手写数字生成 | 第19天