Change loss function dynamically during training in Keras, without recompiling other model properties like optimizer


Posted: 2019-09-22 13:16:45

[Question]:

Is it possible to set model.loss in a callback without recompiling model.compile(...) afterwards (since then the optimizer states are reset), and instead just recompiling model.loss, for example:

class NewCallback(Callback):

    def __init__(self):
        super(NewCallback, self).__init__()

    def on_epoch_end(self, epoch, logs={}):
        self.model.loss = [loss_wrapper(t_change, current_epoch=epoch)]
        self.model.compile_only_loss()  # is there a version or hack of
                                        # model.compile(...) like this?

To expand a bit more with examples from Stack Overflow:

To implement a loss function that depends on the epoch number, for example (as in this Stack Overflow question):

def loss_wrapper(t_change, current_epoch):
    def custom_loss(y_true, y_pred):
        c_epoch = K.get_value(current_epoch)
        if c_epoch < t_change:
            return loss_1(y_true, y_pred)  # compute loss_1 (placeholder)
        else:
            return loss_2(y_true, y_pred)  # compute loss_2 (placeholder)
    return custom_loss

where "current_epoch" is a Keras variable updated with a callback:

current_epoch = K.variable(0.)
model.compile(optimizer=opt, loss=loss_wrapper(5, current_epoch),
              metrics=...)

class NewCallback(Callback):
    def __init__(self, current_epoch):
        self.current_epoch = current_epoch

    def on_epoch_end(self, epoch, logs={}):
        K.set_value(self.current_epoch, epoch)

Basically, one can convert the python code into a composition of backend functions, so that the branch lives inside the graph instead of being fixed once at compile time, and the loss works as follows:

def loss_wrapper(t_change, current_epoch):
    def custom_loss(y_true, y_pred):
        # compute loss_1 and loss_2
        bool_case_1 = K.less(current_epoch, t_change)
        num_case_1 = K.cast(bool_case_1, "float32")
        loss = num_case_1 * loss_1 + (1 - num_case_1) * loss_2
        return loss
    return custom_loss
And it works.

I am not happy with these hacks, and wonder: is it possible to set model.loss in a callback without recompiling model.compile(...) afterwards (since then the optimizer states are reset), and just recompiling model.loss?

[Comments]:

Did you solve this? Do you need to keep the whole optimizer state, or just the weights? If it's just the weights, maybe save them, recompile, and then load them. There is Model.load_weights(..., by_name=True) for loading into a different model than the one they were saved from. There is also saving/loading of the whole state like ***.com/questions/49503748/…, but I'm not sure it lets you change the architecture.

Did you find a solution? I have the same problem.

I think using a dynamic computation graph or eager execution mode with tf 2.0 will solve this problem.

Following your last approach, I don't find it too hard to keep a single loss function. You could also use model.add_loss() to do something similar without using a wrapper.

[Answer 1]:

I hope you have found a solution to your problem by now, but using tensorflow I think you can solve this by building a custom training loop (here is the doc). This will not override the loss attribute as you asked, but you can probably achieve what you want.

Example

Initialize variables

Modifying the example from the docs with a model and a dataset:

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(784,), name="digits")
x1 = tf.keras.layers.Dense(64, activation="relu")(inputs)
x2 = tf.keras.layers.Dense(64, activation="relu")(x1)
outputs = tf.keras.layers.Dense(10, name="predictions")(x2)
model = tf.keras.Model(inputs=inputs, outputs=outputs)


# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

We can define our two loss functions (the two I chose make no sense from a scientific point of view, but let us check that the code works):

# Instantiate an optimizer.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_1 = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_2 = lambda y_true, y_pred: -1 * loss_1(y_true, y_pred)

Training loop

We can then run our custom training loop:

epochs = 10
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Alternate the loss from one epoch to the next:
        # loss_2 on even-numbered epochs, loss_1 on odd ones.
        loss_fn = loss_1 if epoch % 2 else loss_2

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * 64))

We check that the output is what we want (the loss alternates sign from epoch to epoch):

Start of epoch 0
Training loss (for one batch) at step 0: -96.1003
Seen so far: 64 samples
Training loss (for one batch) at step 200: -3383849.5000
Seen so far: 12864 samples
Training loss (for one batch) at step 400: -40419124.0000
Seen so far: 25664 samples
Training loss (for one batch) at step 600: -149133008.0000
Seen so far: 38464 samples
Training loss (for one batch) at step 800: -328322816.0000
Seen so far: 51264 samples

Start of epoch 1
Training loss (for one batch) at step 0: 580457984.0000
Seen so far: 64 samples
Training loss (for one batch) at step 200: 297710528.0000
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 213328544.0000
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 159328976.0000
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 105737024.0000
Seen so far: 51264 samples

Drawbacks and further improvements

The problem with writing custom loops like this is that you lose the convenience of keras' fit method. I think you can manage this by defining a custom model and overriding train_step, as shown here in the docs.

If you really need to change the loss attribute of your model, you can set the compiled_loss attribute using a keras.engine.compile_utils.LossesContainer (here is the reference) and set model.train_function to model.make_train_function() (so that the new loss is taken into account).
