如何在 Tensorflow 中异步更新 GAN 生成器和判别器?

Posted

技术标签:

【中文标题】如何在 Tensorflow 中异步更新 GAN 生成器和判别器?【英文标题】:How to update GAN Generator and Discriminator asynchronously in Tensorflow? 【发布时间】:2018-06-06 22:16:02 【问题描述】:

我想用 Tensorflow 开发一个 GAN,生成器是自动编码器,鉴别器是具有二进制输出的卷积神经网络。开发自动编码器和 CNN 没有问题,但我的想法是为每个组件(判别器和生成器)训练 1 个 epoch,并重复此循环 1000 个 epoch,保持上一个训练 epoch 的结果(权重)为下一个。我该如何操作呢?

【问题讨论】:

【参考方案1】:

如果您有两个 ops,分别称为 train_step_generatortrain_step_discriminator(例如,每个都是 tf.train.AdamOptimizer().minimize(loss) 的形式,每个都有适当的损失),那么您的训练循环应该类似于以下内容结构:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1000):
        if epoch%2 == 0: # train discriminator on even epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
                batch = get_next_batch(batch_size)
                sess.run(train_step_discriminator,feed_dict=z:z_, x:batch)
        else: # train generator on odd epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size)  # this is the input to the generator
                sess.run(train_step_generator,feed_dict=z:z_)

权重将在迭代之间保持不变。

【讨论】:

【参考方案2】:

我解决了这个问题。实际上,我希望自动编码器的输出成为 CNN 的输入,连接 GAN 并以 1:1 的比例更新权重。我注意到我必须特别注意区分生成器和判别器的损失,否则在第二个循环开始时,生成器的张量损失将被浮点数替换,这是判别器生成的最后一个损失。

代码如下:

with tf.Session() as sess:
sess.run(init)
for i in range(1, num_steps+1):

这里是生成器训练

    batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)        
    _, l = sess.run([optimizer, loss], feed_dict=X: batch_x.reshape(n,784),
                    Y:batch_y)
    if i % display_step == 0 or i == 1:
        print('Epoch %i: Denoising Loss: %f' % (i, l))

这里生成器的输出将用作判别器的输入

    output=sess.run([decoder_op],feed_dict=X: x_train)
    x_train2=np.array(output).reshape(n,784).astype(np.float64)

这里是判别器训练

    batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
    sess.run(train_op, feed_dict=X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8)
    if i % display_step == 0 or i == 1:
        loss3, acc = sess.run([loss_op2, accuracy], feed_dict=X2: batch_x2,
                                                             Y2: batch_y2,
                                                             keep_prob: 1.0)
        print("Epoch " + str(i) + ", CNN Loss= " + \
              ":.4f".format(loss3) + ", Training Accuracy= " + ":.3f".format(acc))

这样异步更新可以按照 1:1、1:5、5:1(判别器:生成器)的比例或任何其他方式进行操作

【讨论】:

以上是关于如何在 Tensorflow 中异步更新 GAN 生成器和判别器?的主要内容,如果未能解决你的问题,请参考以下文章

如何以Tensorflow为切入点掌握GAN | 迁移学习 | 强化学习

Tensorflow+kerasKeras API两种训练GAN网络的方式

Tensorflow+kerasKeras API两种训练GAN网络的方式

美团云Tensorflow生成对抗网络(Generative Adversarial Networks)实战案例

无法理解 tensorflow 文档中使用的 GAN 模型的损失函数

在TensorFlow中对比两大生成模型:VAE与GAN