使用@tf.function 进行自定义张量流训练的内存泄漏
Posted
技术标签:
【中文标题】使用@tf.function 进行自定义张量流训练的内存泄漏【英文标题】:Memory leak for custom tensorflow training using @tf.function 【发布时间】:2021-07-10 23:07:15 【问题描述】:我正在尝试按照 Keras 官方演练为 TF2/Keras
编写自己的训练循环。 vanilla 版本就像一个魅力,但是当我尝试将 @tf.function
装饰器添加到我的训练步骤时,一些内存泄漏会占用我所有的内存并且我失去对我的机器的控制,有人知道发生了什么吗?。
代码的重要部分如下所示:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = siamese_network(x, training=True)
loss_value = loss_fn(y, logits)
grads = tape.gradient(loss_value, siamese_network.trainable_weights)
optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
train_acc_metric.update_state(y, logits)
return loss_value
@tf.function
def test_step(x, y):
val_logits = siamese_network(x, training=False)
val_acc_metric.update_state(y, val_logits)
val_prec_metric.update_state(y_batch_val, val_logits)
val_rec_metric.update_state(y_batch_val, val_logits)
for epoch in range(epochs):
step_time = 0
epoch_time = time.time()
print("Start of epoch".format(epoch))
for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
if step > steps_epoch:
break
loss_value = train_step(x_batch_train, y_batch_train)
train_acc = train_acc_metric.result()
train_acc_metric.reset_states()
for val_step,(x_batch_val, y_batch_val) in enumerate(test_ds):
if val_step>validation_steps:
break
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_prec = val_prec_metric.result()
val_rec = val_rec_metric.result()
val_acc_metric.reset_states()
val_prec_metric.reset_states()
val_rec_metric.reset_states()
如果我评论@tf.function
行,则不会发生内存泄漏,但步骤时间慢了 3 倍。我的猜测是,不知何故,图表是在每个时期或类似的情况下再次创建的 bean,但我不知道如何解决它。
这是我正在学习的教程:https://keras.io/guides/writing_a_training_loop_from_scratch/
【问题讨论】:
您使用的是 GPU 吗?如果不是,则将其更改为 GPU。另外,尽量减少批量大小。 您的train_ds
和test_ds
是如何创建的?当您枚举它们时,您会得到张量还是其他类型?
【参考方案1】:
tl;博士;
TensorFlow 可能会为传递给修饰函数的每个唯一参数值集生成一个新图。确保将形状一致的 Tensor
对象传递给 test_step
和 train_step
而不是 python 对象。
详情
这是在黑暗中刺伤。虽然我从未尝试过@tf.function
,但我确实在the documentation 中发现了以下警告:
tf.function 还将任何纯 Python 值视为不透明对象,并为其遇到的每组 Python 参数构建一个单独的图。
和
警告:将 python 标量或列表作为参数传递给 tf.function 将始终构建一个新图。为避免这种情况,请尽可能将数字参数作为张量传递
最后:
函数通过从输入的 args 和 kwargs 计算缓存键来确定是否重用跟踪的 ConcreteFunction。缓存键是根据函数调用的输入 args 和 kwargs 标识 ConcreteFunction 的键,根据以下规则(可能会更改):
为 tf.Tensor 生成的键是它的 shape 和 dtype。 为 tf.Variable 生成的键是唯一的变量 id。 为 Python 原语(如 int、float、str)生成的键是它的值。 为嵌套字典、列表、元组、namedtuples 和 attrs 生成的键是叶键的扁平元组(请参阅nest.flatten)。 (由于这种扁平化,调用具有与跟踪期间使用的嵌套结构不同的嵌套结构的具体函数将导致 TypeError)。 对于所有其他 Python 类型,键对于对象是唯一的。这样,一个函数或方法就可以独立地跟踪每个调用它的实例。
我从这一切中得到的是,如果您没有将大小一致的张量对象传递给您的 @tf.function
-ified 函数(也许您使用 Python 集合或原语),那么您很可能正在创建函数的新图形版本,其中包含您传入的每个不同的参数值。我猜这可能会产生您所看到的内存爆炸行为。我不知道您的 test_ds
和 train_ds
对象是如何创建的,但您可能希望确保创建它们以便 enumerate(blah_ds)
像教程中一样返回张量,或者至少将值转换为张量在传递给您的 test_step
和 train_step
函数之前。
【讨论】:
嗨!我正在训练的网络是连体网络,因此我有两个输入。最初,它们被表示为字典,实际上是一个 python 对象。我尝试使用x_prueba = tf.convert_to_tensor(x_prueba)
将它们转换为张量,并将这个新变量用作 train_step 函数的输入,但内存使用量仍在增长。它们属于 eagerTensor 类型,但不知道这是否相关。
我使用的是 TF 1.15,但是当我运行 tf.convert_to_tensor(my_dict)
时,我得到了 TypeError: Failed to convert object of type <class 'dict'> to Tensor.
。可能是版本不同吧。你确定你传递的是张量而不是张量字典吗?
我的意思是将张量而不是张量字典传递给您的 tf.function
-decorated 函数...
我必须做一些预处理。我不直接将字典转换为张量:我将字典的两个条目附加到一个列表中,然后将该列表转换为维度为 [2,35,280,4] 的张量,其中“2”是额外维度添加。这个张量是使用 TF 2.4 的 train_step 函数 Im 的 x 输入,以防万一
在那种情况下,听起来我的猜测是错误的。以上是关于使用@tf.function 进行自定义张量流训练的内存泄漏的主要内容,如果未能解决你的问题,请参考以下文章