Tensorflow 2 + Keras 的知识蒸馏损失

Posted

技术标签:

【中文标题】Tensorflow 2 + Keras 的知识蒸馏损失【英文标题】:Knowledge Distillation loss with Tensorflow 2 + Keras 【发布时间】:2020-03-27 00:08:54 【问题描述】:

我正在尝试实现一个非常简单的 keras 模型,该模型使用来自另一个模型的知识蒸馏 [1]。 粗略地说,我需要用L(y_true, y_pred)+L(y_teacher_pred, y_pred) 替换原始损失L(y_true, y_pred),其中y_teacher_pred 是另一个模型的预测。

我已经尝试过了

def create_student_model_with_distillation(teacher_model):

  inp = tf.keras.layers.Input(shape=(21,))

  model = tf.keras.models.Sequential()
  model.add(inp)

  model.add(...) 
  model.add(tf.keras.layers.Dense(units=1))

  teacher_pred = teacher_model(inp)

  def my_loss(y_true,y_pred):
      loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
      loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
      return loss

  model.compile(loss=my_loss, optimizer='adam')

  return model

但是,当我尝试在我的模型上调用 fit 时,我得到了

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.

我该如何解决这个问题?

参考

[1]https://arxiv.org/abs/1503.02531

【问题讨论】:

问题很可能是teacher_pred = teacher_model(inp)。 Keras 正在尝试通过您的教师模型反向传播梯度。您可以在创建数据集而不是在损失函数中生成教师模型 logits。 【参考方案1】:

其实这篇博文就是回答你的问题:keras blog

但简而言之 - 你应该使用新的 TF2 API 并在 tf.GradientTape() 块之前调用老师的 predict

def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

【讨论】:

以上是关于Tensorflow 2 + Keras 的知识蒸馏损失的主要内容,如果未能解决你的问题,请参考以下文章

无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型

Keras一些常用的API总结

TensorFlow+Keras深度学习人工智能实践应用_林大贵

Keras 2.3.0 发布:支持TensorFlow 2.0!!!!!

TensorFlow2.0TensorFlow 2.0高阶API: Keras—使用Keras基于Squential的序列编排模式创建神经网络过程(附带源码)

Keras TensorFlow 2.0 精华资源