Keras 中的 .fit() 方法触发损失函数多少次
Posted
技术标签:
【中文标题】Keras 中的 .fit() 方法触发损失函数多少次【英文标题】:How many times the loss function is triggered from .fit() method in Keras 【发布时间】:2021-08-20 06:33:51 【问题描述】:我正在尝试在自定义损失函数中进行一些自定义计算。但是当我从自定义损失函数中记录语句时,似乎自定义损失函数只被调用一次(在 .fit() 方法的开头)。
损失函数示例:
def loss(y_true, y_pred):
print("--- Starting of the loss function ---")
print(y_true)
loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
print("--- Ending of the loss function ---")
return loss
使用回调检查批处理的开始和结束时间:
class monitor(Callback):
def on_batch_begin(self, batch, logs=None):
print("\n >> Starting a new batch (batch index) :: ", batch)
def on_batch_end(self, batch, logs=None):
print(">> Ending a batch (batch index) :: ", batch)
.fit() 方法用作:
history = model.fit(
x=[inputs],
y=[outputs],
shuffle=False,
batch_size=BATCH_SIZE,
epochs=NUM_EPOCH,
verbose=1,
callbacks=[monitor()]
)
使用的参数:
BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)
还有输出:
Epoch 1/3
>> Starting a new batch (batch index) :: 0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307
Epoch 2/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246
Epoch 3/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219
为什么自定义损失函数只在开始时调用,而不是每次批量计算都调用?我还想知道何时调用/触发损失函数?
【问题讨论】:
仔细观察输出,每批都会调用损失函数。 @ShubhamPanchal 那为什么print("--- Starting of the loss function ---")
和print("--- Ending of the loss function ---")
语句不是每次调用都打印出来的?
我明白你的观点@Rahul。我不知道为什么您没有看到所有时代都印有这些声明。然而,在每个时期都会打印损失值这一事实让我认为损失是正确计算的。打印可能无法正常工作。您可以尝试改用日志记录
【参考方案1】:
损失函数调试消息仅在 训练开始。
这是因为为了提高性能,您的损失函数在内部被转换为张量流图,而 python 打印函数仅在您的函数被跟踪时才有效。即它仅在训练开始时打印,这意味着当时正在跟踪您的损失函数。更多信息请参考以下页面:https://www.tensorflow.org/guide/function
简答:要正确打印,请使用 tf.print() 而不是 print()
我也想知道损失函数什么时候被调用/触发?
使用 tf.print() 后,调试信息将正确打印。您将看到您的损失函数在每一步中至少被调用一次,以获取损失值和梯度。
【讨论】:
以上是关于Keras 中的 .fit() 方法触发损失函数多少次的主要内容,如果未能解决你的问题,请参考以下文章
使用 fit_generator 时 Keras 中的嘈杂验证损失
如何在 tf.keras 自定义损失函数中触发 python 函数?