How to catch any Exception during Model Training in Tensorflow 2
Posted: 2020-02-29 16:07:10

Question:

I am training a Unet model with Tensorflow. If there is a problem with any of the images I pass to the model for training, an exception is raised. Sometimes this happens an hour or two into training. Is it possible to catch any such exception so that my model can skip the problematic image and resume training? I tried adding a try/except block to the process_path function shown below, but it had no effect...
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

def process_path(filePath):
    # catching exceptions here has no effect
    # derive the mask path from the image path
    parts = tf.strings.split(filePath, '/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName, '.')
    prefix = tf.convert_to_tensor(maskDir, dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.png", dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2], suffix))
    maskPath = tf.strings.join((prefix, maskFileName), separator='/')
    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.io.read_file(maskPath)
    oneHot = decodeMask(mask)
    img.set_shape([256, 256, 3])
    oneHot.set_shape([256, 256, 10])
    return img, oneHot
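One workaround (not from the original post, only a sketch): move the fragile decoding into a tf.py_function, which runs as real Python at iteration time, so an ordinary try/except does work there. The helper name safe_load and the 256x256 shape are illustrative assumptions:

```python
import tensorflow as tf

def safe_load(path):
    """Decode an image inside tf.py_function so Python try/except works."""
    def _py_load(p):
        try:
            img = tf.io.decode_image(tf.io.read_file(p), channels=3,
                                     expand_animations=False)
            img = tf.image.resize(tf.cast(img, tf.float32), [256, 256])
            return img, True
        except tf.errors.OpError:
            # Corrupt or unreadable file: return a dummy image and a False flag.
            return tf.zeros([256, 256, 3], tf.float32), False

    img, ok = tf.py_function(_py_load, [path], [tf.float32, tf.bool])
    img.set_shape([256, 256, 3])
    ok.set_shape([])
    return img, ok
```

The pipeline can then drop the flagged elements with, for example, `ds.map(safe_load).filter(lambda img, ok: ok).map(lambda img, ok: img)`, at the cost of py_function being slower than pure graph ops.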
trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
batchSize = 32
allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))
trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.shuffle(1000).repeat()
validDataSet = validDataSet.map(process_path)
validDataSet = validDataSet.batch(batchSize)
imageHeight = 256
imageWidth = 256
channels = 3
inputImage = Input((imageHeight, imageWidth, channels), name='img')
model = baseUnet.get_unet(inputImage, n_filters=16, dropout=0.05, batchnorm=True)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    ModelCheckpoint(outputModel, verbose=1, save_best_only=True, save_weights_only=False)
]
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EPOCHS = 20
stepsPerEpoch = int(trainSize / BATCH_SIZE)
validationSteps = int(validSize / BATCH_SIZE)
model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)
The link below shows a similar case and explains that "the Python function is only executed once, to build the function graph, so the try and except statements there will have no effect." It does show how to iterate over the dataset and catch errors...
dataset = ...
iterator = iter(dataset)
while True:
    try:
        elem = next(iterator)
        ...
    except InvalidArgumentError:
        ...
    except StopIteration:
        break
...but I am looking for a way to catch errors during training itself. Is this possible?
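For reference, the iteration pattern above can be completed into a runnable sketch. The tiny dataset with one deliberately unparseable element, and the loop bound in place of `while True`, are illustrative only:

```python
import tensorflow as tf

# "oops" cannot be parsed, so that element raises InvalidArgumentError at runtime.
dataset = tf.data.Dataset.from_tensor_slices(["1.5", "oops", "3.0"])
dataset = dataset.map(tf.strings.to_number)

good = []
iterator = iter(dataset)
for _ in range(100):  # safety bound for this small example; the real loop is `while True`
    try:
        good.append(float(next(iterator)))
    except tf.errors.InvalidArgumentError:
        continue  # skip the element that failed to parse
    except StopIteration:
        break
```

After the loop, good should hold only the successfully parsed values, with the bad element skipped.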
Comments:

Did you find a solution?

Answer 1:

You might consider using the tf.data.experimental.ignore_errors function to silently drop the files that are causing the problem.
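A minimal sketch of that suggestion, following the pattern in the TensorFlow docs; the toy dataset and the check_numerics failure stand in for a corrupt image:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
# 1/0 is inf, so check_numerics raises InvalidArgumentError for the third element.
dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "bad element"))
# Silently drop any element whose production raised an error.
dataset = dataset.apply(tf.data.experimental.ignore_errors())
result = list(dataset.as_numpy_iterator())  # the bad element is gone
```

In the question's pipeline this would mean applying ignore_errors() right after the map(process_path, ...) call, so a corrupt image is dropped instead of aborting model.fit. Note the trade-off: errors are swallowed silently, so it is worth logging dataset sizes to notice how many elements are being discarded.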