Applying callbacks in a custom training loop in TensorFlow 2.0
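【Title】: Applying callbacks in a custom training loop in Tensorflow 2.0 【Posted】: 2020-04-13 18:33:38
【Question】: I am writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I want to add callbacks to the training loop. In Keras, I know we pass them as an argument to the 'fit' method, but I can't find resources on how to use these callbacks in a custom training loop. Here is the custom training loop code from the Tensorflow documentation: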
import time

import tensorflow as tf
from IPython import display

# generator, discriminator, the loss functions, the optimizers, checkpoint,
# checkpoint_prefix, seed, BATCH_SIZE and noise_dim are defined earlier in
# the DCGAN tutorial.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             seed)
【Question comments】:
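【Answer 1】: I've had this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; (3) I don't want to re-implement them all myself. Tensorflow has a design philosophy of letting a developer gradually opt in to its lower-level APIs. As @HyeonPhilYoun notes in the comments below, the official documentation for tf.keras.callbacks.Callback gives an example of what we are looking for.
The following has worked for me, but it can be improved by reverse engineering tf.keras.Model.
The trick is to use tf.keras.callbacks.CallbackList and then manually trigger its lifecycle events from within your custom training loop. This example uses tqdm to give attractive progress bars, but CallbackList has an add_progbar initialization argument that lets you use the default progress bar instead. training_model is a typical instance of tf.keras.Model.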
import tensorflow as tf
from tqdm.notebook import tqdm, trange

# training_model, max_epochs, the batches() helper and the arrays
# x, Y, test_x, test_Y are assumed to be defined elsewhere.

# Populate with typical keras callbacks
_callbacks = []

callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)

logs = {}
callbacks.on_train_begin(logs=logs)

# Presentation
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
    postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}")
epochs.set_postfix(loss=0, accuracy=0)

# Get a stable test set so epoch results are comparable
test_batches = batches(test_x, test_Y)

for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)

    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play
    training_batches = batches(x, Y)

    # Presentation
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}",
        position=1,
        leave=False)

    # The loop variables are named x_batch/y_batch so that they do not
    # shadow the full arrays x and Y used to build the batches.
    for (batch, (x_batch, y_batch)) in enumerated_batches:
        training_model.reset_states()

        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_train_batch_begin(batch, logs=logs)

        logs = training_model.train_on_batch(x=x_batch, y=y_batch, return_dict=True)

        callbacks.on_train_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

        # Presentation
        enumerated_batches.set_postfix(
            loss=float(logs["loss"]),
            accuracy=float(logs["accuracy"]))

    for (batch, (x_batch, y_batch)) in enumerate(test_batches):
        training_model.reset_states()

        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_test_batch_begin(batch, logs=logs)

        logs = training_model.test_on_batch(x=x_batch, y=y_batch, return_dict=True)

        callbacks.on_test_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

    # Presentation
    epochs.set_postfix(
        loss=float(logs["loss"]),
        accuracy=float(logs["accuracy"]))

    callbacks.on_epoch_end(epoch, logs=logs)

    # NOTE: This is a decent place to check on your early stopping
    # callback.
    # Example: use training_model.stop_training to check for early stopping

callbacks.on_train_end(logs=logs)

# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
    if isinstance(cb, tf.keras.callbacks.History):
        history_object = cb
assert history_object is not None
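The answer above drives the callbacks around train_on_batch, while the question uses a tf.GradientTape step. As a minimal sketch of how the same lifecycle calls could wrap a GradientTape loop: the toy data, the one-layer model and the StopAfter callback are all hypothetical stand-ins invented for illustration, and the step runs eagerly (without tf.function) to keep the example short. Any stock callback that relies only on logs and model.stop_training could presumably be placed in the list the same way.

```python
import numpy as np
import tensorflow as tf

# Toy regression data and model; stand-ins for the real dataset and networks.
x_train = np.random.rand(64, 4).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
loss_fn = tf.keras.losses.MeanSquaredError()


class StopAfter(tf.keras.callbacks.Callback):
    """Hypothetical callback: records hooks and requests a stop after n epochs."""

    def __init__(self, n_epochs):
        super().__init__()
        self.n_epochs = n_epochs
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")
        if epoch + 1 >= self.n_epochs:
            # Same flag that callbacks such as EarlyStopping set.
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def train_step(x, y):
    # One eager gradient step; returns the batch loss so it can go into logs.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


stopper = StopAfter(n_epochs=3)
callbacks = tf.keras.callbacks.CallbackList([stopper], model=model)

callbacks.on_train_begin()
for epoch in range(10):
    callbacks.on_epoch_begin(epoch)
    for step, (x, y) in enumerate(dataset):
        callbacks.on_train_batch_begin(step)
        loss = train_step(x, y)
        callbacks.on_train_batch_end(step, logs={"loss": float(loss)})
    callbacks.on_epoch_end(epoch, logs={"loss": float(loss)})
    if model.stop_training:  # honour whatever a callback requested
        break
callbacks.on_train_end()
```

The loop itself never decides when to stop; it only checks model.stop_training after each epoch, which is the contract the callbacks expect from fit.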
【Comments】:
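Thank you for your extensive answer! This really helped me. It is annoying that there is no more official documentation on callbacks in custom loops!
The official documentation also notes that this approach is the most appropriate one. You can check its example section. tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
【Answer 2】: The simplest way is to check whether the loss has changed over the period you expect, and to break out of or otherwise adjust the training process if it has not. Here is one way to implement a custom early stopping callback: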
import numpy as np

def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
    # No early stopping for 2*patience epochs
    if len(LossList)//patience < 2:
        return False
    # Mean loss for last patience epochs and second-last patience epochs
    mean_previous = np.mean(LossList[::-1][patience:2*patience])  # second-last
    mean_recent = np.mean(LossList[::-1][:patience])  # last
    # you can use relative or absolute change
    delta_abs = np.abs(mean_recent - mean_previous)  # abs change
    delta_abs = np.abs(delta_abs / mean_previous)  # relative change
    if delta_abs < min_delta:
        print("*CB_ES* Loss didn't change much from last %d epochs"%(patience))
        print("*CB_ES* Percent change in loss value:", delta_abs*1e2)
        return True
    else:
        return False
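This Callback_EarlyStopping checks your metric/loss at every epoch. It compares the mean loss over the last patience epochs with the mean over the patience epochs before that, and returns True if the relative change is smaller than min_delta. You can then catch this True signal and break out of the training loop. To fully answer your question, in your example training loop you can use it as: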
gen_loss_seq = []
for epoch in range(epochs):
    # in your example, make sure your train_step returns gen_loss
    gen_loss = train_step(dataset)
    # ideally, you can have a validation_step and get gen_valid_loss
    gen_loss_seq.append(gen_loss)
    # check every 20 epochs and stop if gen_valid_loss doesn't change by 10%
    stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
    if stopEarly:
        print("Callback_EarlyStopping signal received at epoch= %d/%d"%(epoch,epochs))
        print("Terminating training ")
        break
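Of course, you can add complexity in many ways: which losses or metrics you track, whether you look at the loss of a particular epoch or at a moving average, whether you care about relative or absolute change, and so on. For reference, see the Tensorflow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is normally used with the popular tf.keras.Model.fit method.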
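Rather than re-implementing the logic, the stock tf.keras.callbacks.EarlyStopping can also be driven by hand. The following is only a sketch under assumptions: the placeholder model and the fake_val_losses list are made up for illustration, and in a real loop the value in logs would come from your own validation step.

```python
import tensorflow as tf

# Placeholder model; the callback only needs it to set model.stop_training.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=2)
early_stopping.set_model(model)

# Made-up validation losses standing in for the output of a validation step.
fake_val_losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.5]

early_stopping.on_train_begin()
epochs_run = 0
for epoch, val_loss in enumerate(fake_val_losses):
    # ... run train_step over the dataset here ...
    epochs_run += 1
    early_stopping.on_epoch_end(epoch, logs={"val_loss": val_loss})
    if model.stop_training:
        print(f"Stopping after {epochs_run} epochs")
        break
early_stopping.on_train_end()
```

The key under logs has to match the monitor argument; the callback never computes the metric itself.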
【Comments】:
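Unfortunately, this answer only covers the very specific case where one wants the EarlyStopping callback. There are many other useful callbacks that people may want to reuse rather than implement from scratch.
【Answer 3】: I think you need to implement the callback's functionality manually. It should not be too difficult. For instance, you could have the "train_step" function return the losses and then implement callback behaviour such as early stopping in your "train" function. For callbacks such as a learning rate schedule, the function tf.keras.backend.set_value(generator_optimizer.lr, new_lr) comes in handy. The functionality of the callback would therefore live in your "train" function.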
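As a rough illustration of the learning rate schedule idea, here is a sketch of a step-decay rule in plain Python. The step_decay function and its drop and epochs_per_drop parameters are hypothetical names chosen for this example, not part of TensorFlow.

```python
def step_decay(initial_lr, epoch, drop=0.5, epochs_per_drop=10):
    """Return the learning rate for a 0-based epoch under step decay."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


# Learning rate for each of 30 epochs, starting from 1e-3.
schedule = [step_decay(1e-3, epoch) for epoch in range(30)]

# Inside the "train" function, each value would be pushed into the optimizer
# at the start of the epoch, for example with the call named in the answer:
#   tf.keras.backend.set_value(generator_optimizer.lr, new_lr)
```

Keeping the schedule as a pure function of the epoch makes it easy to check on its own before wiring it into the loop.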
【Comments】:
【Answer 4】: A custom training loop is just a normal Python loop, so you can use if statements to break out of it whenever some condition is met. For example:
if len(loss_history) > patience:
    if loss_history.popleft()*delta < min(loss_history):
        print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
              f'validation loss in the last {patience} epochs.')
        break
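If the loss has not improved by delta% over the last patience epochs, the loop is broken. Here I use a collections.deque, which works nicely as a rolling list that keeps only the information for the last patience epochs in memory.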
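To make the rolling-list behaviour concrete, here is a small standalone sketch; the loss values are made up. A deque created with maxlen silently drops its oldest entry once it is full.

```python
from collections import deque

patience = 3
loss_history = deque(maxlen=patience + 1)

# Made-up losses, one per epoch.
for loss in [0.50, 0.40, 0.35, 0.36, 0.37, 0.38]:
    loss_history.append(loss)  # the oldest entry falls off once the deque is full
    print(list(loss_history))
```

After the loop only the last patience + 1 values remain, which is all the early stopping check needs.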
Here is a full implementation, using the example from the Tensorflow documentation:
from collections import deque

import tensorflow as tf

# model, optimizer, get_grad, the train and test datasets and the
# verbose format string are assumed to be defined elsewhere.

patience = 3
delta = 0.001

loss_history = deque(maxlen=patience + 1)

for epoch in range(1, 25 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.CategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.CategoricalAccuracy()

    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))

    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))

    loss_history.append(test_loss.result())

    if len(loss_history) > patience:
        if loss_history.popleft()*delta < min(loss_history):
            print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                  f'validation loss in the last {patience} epochs.')
            break
Epoch 1 Loss: 0.191 TLoss: 0.282 Acc: 68.920% TAcc: 89.200%
Epoch 2 Loss: 0.157 TLoss: 0.297 Acc: 70.880% TAcc: 90.000%
Epoch 3 Loss: 0.133 TLoss: 0.318 Acc: 71.560% TAcc: 90.800%
Epoch 4 Loss: 0.117 TLoss: 0.299 Acc: 71.960% TAcc: 90.800%
Early stopping. No improvement of more than 0.10000% in validation loss in the last 3 epochs.
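One thing worth checking in the loop above: the condition multiplies the oldest loss by delta, so with delta = 0.001 it holds unless the loss has dropped roughly a thousandfold, and the loop stops at the first opportunity even when the loss is still falling. The sketch below compares that check with one possible reading of the intent, multiplying by (1 - delta) instead. This is an assumption about what was meant, not the original author's code, and the function names are hypothetical.

```python
from collections import deque


def original_check(oldest, recent, delta):
    # Condition as written in the answer.
    return oldest * delta < min(recent)


def relative_check(oldest, recent, delta):
    # Assumed intent: no recent loss beats the oldest by more than delta (as a fraction).
    return oldest * (1 - delta) < min(recent)


def stop_epoch(losses, check, patience=3, delta=0.001):
    """Return the 1-based epoch at which `check` stops training, or None."""
    history = deque(maxlen=patience + 1)
    for epoch, loss in enumerate(losses, start=1):
        history.append(loss)
        if len(history) > patience:
            oldest = history.popleft()
            if check(oldest, history, delta):
                return epoch
    return None


improving = [0.5, 0.4, 0.3, 0.2, 0.1]   # made-up, steadily falling losses
plateau = [0.282, 0.297, 0.318, 0.299]  # TLoss values from the output above

print(stop_epoch(improving, original_check), stop_epoch(improving, relative_check))
print(stop_epoch(plateau, original_check), stop_epoch(plateau, relative_check))
```

On the plateau sequence both checks stop at epoch 4, which matches the printed output; they differ only on the steadily improving sequence.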
【Comments】:
【Answer 5】: aapa3e8's answer is correct, but below I provide an implementation of Callback_EarlyStopping that is closer to tf.keras.callbacks.EarlyStopping:
def Callback_EarlyStopping(MetricList, min_delta=0.1, patience=20, mode='min'):
    # No early stopping for the first patience epochs
    if len(MetricList) <= patience:
        return False

    min_delta = abs(min_delta)
    if mode == 'min':
        min_delta *= -1
    else:
        min_delta *= 1

    # last patience epochs
    last_patience_epochs = [x + min_delta for x in MetricList[::-1][1:patience + 1]]
    current_metric = MetricList[::-1][0]

    if mode == 'min':
        if current_metric >= max(last_patience_epochs):
            print(f'Metric did not decrease for the last {patience} epochs.')
            return True
        else:
            return False
    else:
        if current_metric <= min(last_patience_epochs):
            print(f'Metric did not increase for the last {patience} epochs.')
            return True
        else:
            return False
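The function above re-reads the whole metric list on every call. A stateful variant that keeps a best value and a wait counter is another option; the sketch below is a hypothetical helper written for illustration (the EarlyStopper name and its methods are not from TensorFlow), and its stopping rule differs slightly from the function above because it compares against the best value seen so far.

```python
class EarlyStopper:
    """Hypothetical helper that tracks the best metric and a wait counter."""

    def __init__(self, min_delta=0.0, patience=3, mode="min"):
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = None
        self.wait = 0

    def _improved(self, current):
        if self.best is None:
            return True
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def should_stop(self, current):
        # Reset the counter on improvement, otherwise count a stalled epoch.
        if self._improved(current):
            self.best = current
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper(min_delta=0.05, patience=3, mode="min")
stopped_at = None
# Made-up validation losses, one per epoch.
for epoch, loss in enumerate([1.0, 0.8, 0.79, 0.81, 0.80, 0.5]):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

In a training loop, should_stop would be called once per epoch with the monitored value, followed by a break when it returns True.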
【Comments】:
Unfortunately, the question is not about early stopping but about callbacks in general. Why does everyone here assume the asker only wants this one particular callback?