Keras 中的 .fit() 方法触发损失函数多少次

Posted

技术标签:

【中文标题】Keras 中的 .fit() 方法触发损失函数多少次【英文标题】:How many times the loss function is triggered from .fit() method in Keras 【发布时间】:2021-08-20 06:33:51 【问题描述】:

我正在尝试在自定义损失函数中进行一些自定义计算。但是当我从自定义损失函数中记录语句时,似乎自定义损失函数只被调用一次(在 .fit() 方法的开头)。

损失函数示例:

def loss(y_true, y_pred):
    print("--- Starting of the loss function ---")
    print(y_true)
    loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
    print("--- Ending of the loss function ---")
    return loss

使用回调检查批处理的开始和结束时间:

class monitor(Callback):
    def on_batch_begin(self, batch, logs=None):
        print("\n >> Starting a new batch (batch index) :: ", batch)

    def on_batch_end(self, batch, logs=None):
        print(">> Ending a batch (batch index) :: ", batch)

.fit() 方法用作:

history = model.fit(
    x=[inputs],
    y=[outputs],
    shuffle=False,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCH,
    verbose=1,
    callbacks=[monitor()]
)

使用的参数:

BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)

还有输出:

Epoch 1/3
 >> Starting a new batch (batch index) ::  0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307

Epoch 2/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246

Epoch 3/3
 >> Starting a new batch (batch index) ::  0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
 >> Ending a batch (batch index) ::  0

 >> Starting a new batch (batch index) ::  1
 >> Ending a batch (batch index) ::  1

 >> Starting a new batch (batch index) ::  2
 >> Ending a batch (batch index) ::  2

 >> Starting a new batch (batch index) ::  3
 >> Ending a batch (batch index) ::  3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219

为什么自定义损失函数只在开始时调用,而不是每次批量计算都调用?我还想知道何时调用/触发损失函数?

【问题讨论】:

仔细观察输出,每批都会调用损失函数。 @ShubhamPanchal 那为什么print("--- Starting of the loss function ---")print("--- Ending of the loss function ---") 语句不是每次调用都打印出来的? 我明白你的观点@Rahul。我不知道为什么您没有看到所有时代都印有这些声明。然而,在每个时期都会打印损失值这一事实让我认为损失是正确计算的。打印可能无法正常工作。您可以尝试改用日志记录 【参考方案1】:

损失函数调试消息仅在 训练开始。

这是因为为了提高性能,您的损失函数在内部被转换为张量流图,而 python 打印函数仅在您的函数被跟踪时才有效。即它仅在训练开始时打印,这意味着当时正在跟踪您的损失函数。更多信息请参考以下页面:https://www.tensorflow.org/guide/function

简答:要正确打印,请使用 tf.print() 而不是 print()

我也想知道损失函数什么时候被调用/触发?

使用 tf.print() 后,调试信息将正确打印。您将看到您的损失函数在每一步中至少被调用一次,以获取损失值和梯度。

【讨论】:

以上是关于Keras 中的 .fit() 方法触发损失函数多少次的主要内容,如果未能解决你的问题,请参考以下文章

使用 fit_generator 时 Keras 中的嘈杂验证损失

在 keras 中制作自定义损失函数

Keras中多输出模型的验证损失和验证数据

如何在 tf.keras 自定义损失函数中触发 python 函数?

Keras 自定义loss函数 focal loss + triplet loss

多输出 Keras 的回归损失函数