如何避免在 TensorFlow 中添加重复的集合?

Posted

技术标签:

【中文标题】如何避免在 TensorFlow 中添加重复的集合?【英文标题】:How do I avoid adding duplicate collections in Tensorflow? 【发布时间】:2017-10-20 04:07:59 【问题描述】:

Tensorflow 初学者在这里。

我有一段代码(一起)对一组图像进行训练和验证。每隔一段时间,在训练循环中,我执行验证并从验证数据集中获取损失。我总结了结果并使用 tensorboard 来查看我的可视化。

我的问题是,我计算了我的损失两次,我不应该这样做。我的代码会清楚地说明发生了什么

获取一些已经分成训练集和验证集的图像,同时构建神经网络:

        images, labels = (
            self._input_pipeline(filenames, self.model_config.BATCH_SIZE))
        v_images, v_labels = (
            self._input_pipeline(v_filenames, self.model_config.BATCH_SIZE))
        logits = self.build_nets(images)
        tf.get_variable_scope().reuse_variables()
        v_logits = self.build_nets(v_images)

设置损失函数:

        _ = self.set_loss(logits, labels)
        validation_step = self.set_loss(v_logits, v_labels)

这就是 set_loss 的样子:

def set_loss(self, y, y_):
    cross_entropy_sum = (
                tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=y_)))
    tf.add_to_collection('cross_entropy_loss', cross_entropy_sum)
    return tf.losses.get_losses()

正在发生的问题是 cross_entropy_loss 被两次添加到集合中,从而给了我交叉熵损失的 2 倍输出

集合“cross_entropy_loss”在主例程中用于计算 cross_entropy_total:

        get_cross_entropy = tf.get_collection('cross_entropy_loss')
        cross_entropy_total = tf.add_n(get_cross_entropy, name='cross_entropy_loss_across_all_gpus')
        tf.summary.scalar("cross entropy loss", cross_entropy_total)

单个摘要操作生成摘要:

        summary_op = tf.summary.merge_all()

训练步骤如下所示:

        train_step = (
            tf.train.GradientDescentOptimizer(model_config.INITIAL_LEARNING_RATE).minimize(cross_entropy_total))

这是最后一块,运行训练块和验证块并写出摘要

                _, cross_entropy = sess.run([train_step, cross_entropy_total])
                if step % self.model_config.SUMMARY_EVERY == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                    #validation
                    _, cross_entropy = sess.run([validation_step, cross_entropy_total])
                    v_summary_str = sess.run(summary_op)
                    v_summary_writer.add_summary(v_summary_str, step)                        

那么有人可以帮助我如何避免计算 cross_entropy_total 两次吗? 例如,如果没有执行验证的损失是 100,如果我插入上面显示的验证部分,它会变成 200

【问题讨论】:

我回答你的问题了吗? 我想知道为什么你应该将验证损失添加到 cross_entropy_loss 集合中,然后通过它更新权重。如果该步骤用于验证,则不应跳过损失以将其添加到仅包含训练损失的集合中。 【参考方案1】:

通常,人们通过feed_dict 参数向sess.run 提供训练和验证数据。这样,“图”就不依赖于数据集部分。

另外,我不清楚为什么您甚至需要在这里使用“集合”。您可以将cross_entropy_sum = tf.reduce_sum ... 提供给minimize

【讨论】:

@wrecktangle tensorflow.org/get_started/mnist/beginners (Crtl-F feed_dict) 另外我想最小化 cross_entropy_total,这是跨多个 gpu 的 cross_entropy_sum,那么如果没有集合,我该如何做到这一点?

以上是关于如何避免在 TensorFlow 中添加重复的集合?的主要内容,如果未能解决你的问题,请参考以下文章

Mongoose:如何避免重复文件?

对于java中如何去除重复的数据

如何避免在 SQLAlchemy - python 的多对多关系表中添加重复项?

从日期中添加或减去天数时如何避免周末或节假日[重复]

如何在c#中将数据添加到集合列表中[重复]

Vue Firestore 查询。 Firestore更新数据时如何避免重复?