How to access tf.data.Dataset within a keras custom callback?
Posted: 2021-01-15 14:47:18

Question: I wrote a custom keras callback to inspect the augmented data coming from a generator. (See this answer for the full code.) However, when I try to use the same callback with a tf.data.Dataset, it gives me an error:
```
  File "/path/to/tensorflow_image_callback.py", line 16, in on_batch_end
    imgs = self.train[batch][images_or_labels]
TypeError: 'PrefetchDataset' object is not subscriptable
```
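Do keras callbacks in general only work with generators, or is it something about the way I wrote mine? Is there a way to modify my callback or my dataset so that it works?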
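The error comes from the data object rather than from the callback machinery: a Keras generator (`Sequence`/`DirectoryIterator`) supports `self.train[batch]`, while a `tf.data.Dataset` defines no `__getitem__` and can only be iterated. The sketch below reproduces this on a hypothetical toy dataset and fetches batch `n` by iterating with `skip()`/`take()`; the names `dataset` and `get_batch` are illustrative only.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for train_dataset: 10 images (4x4x3) with labels,
# batched by 4 and prefetched, i.e. a PrefetchDataset as in the traceback.
images = np.arange(10 * 4 * 4 * 3, dtype=np.float32).reshape(10, 4, 4, 3)
labels = np.arange(10, dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4).prefetch(1)

# tf.data.Dataset defines no __getitem__, so indexing raises TypeError.
try:
    dataset[0]
    subscriptable = True
except TypeError:
    subscriptable = False


def get_batch(ds, batch_index):
    """Fetch batch `batch_index` by iterating (skip/take) instead of indexing."""
    for imgs, lbls in ds.skip(batch_index).take(1):
        return imgs, lbls
    raise IndexError(f"dataset has no batch {batch_index}")


# Batch 2 holds the last two samples (8 and 9).
imgs, lbls = get_batch(dataset, 2)
print(subscriptable, tuple(imgs.shape), lbls.numpy().tolist())
```

Two caveats with this approach: `skip()` replays the pipeline from the start on every lookup, so each call costs time proportional to `batch_index`; and if the pipeline shuffles (by default `shuffle()` reshuffles on each iteration), a fresh iteration returns a different batch from the one the model just trained on.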
我认为这个难题可以分为三个部分。我对任何和所有的更改持开放态度。一、自定义回调类中的init函数:
```python
class TensorBoardImage(tf.keras.callbacks.Callback):
    def __init__(self, logdir, train, validation=None):
        super(TensorBoardImage, self).__init__()
        self.logdir = logdir
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.train = train
        self.validation = validation
```
Second, the `on_batch_end` function in the same class:
```python
    def on_batch_end(self, batch, logs):
        images_or_labels = 0  # 0=images, 1=labels
        imgs = self.train[batch][images_or_labels]
```
Third, instantiating the callback:
```python
import tensorflow_image_callback

tensorboard_image_callback = tensorflow_image_callback.TensorBoardImage(
    logdir=tensorboard_log_dir, train=train_dataset, validation=valid_dataset)

model.fit(train_dataset,
          epochs=n_epochs,
          validation_data=valid_dataset,
          callbacks=[
              tensorboard_callback,
              tensorboard_image_callback
          ])
```
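Callbacks passed to `model.fit` are invoked the same way whether the data comes from a generator or a `tf.data.Dataset`; only the line that indexes `self.train[batch]` has to change. A minimal sketch, assuming the dataset yields `(images, labels)` tuples: the callback keeps its own iterator over the training dataset and pulls one batch per training step, and takes the epoch number from `on_epoch_begin` instead of deriving it. The toy data, model, temporary log directory, and the `max_outputs`/`logged_batches` attributes are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TensorBoardImage(tf.keras.callbacks.Callback):
    """Sketch: log training images from a tf.data.Dataset by iterating it."""

    def __init__(self, logdir, train, validation=None, max_outputs=3):
        super().__init__()
        self.logdir = logdir
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.train = train
        self.validation = validation
        self.max_outputs = max_outputs
        # A Dataset cannot be indexed, so keep a separate iterator instead;
        # repeat() stops next() from raising StopIteration across epochs.
        self._train_iter = iter(train.repeat())
        self.epoch = 0
        self.logged_batches = 0

    def on_epoch_begin(self, epoch, logs=None):
        # Keras passes the epoch number in directly.
        self.epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        images_or_labels = 0  # 0=images, 1=labels
        imgs = next(self._train_iter)[images_or_labels]
        with self.file_writer.as_default():
            # tf.summary.image clips float images to [0, 1]
            tf.summary.image(f"train/batch_{batch}", imgs, step=self.epoch,
                             max_outputs=self.max_outputs)
        self.logged_batches += 1


# Hypothetical toy data standing in for train_dataset / valid_dataset.
images = np.random.rand(10, 4, 4, 3).astype(np.float32)
targets = np.random.rand(10, 1).astype(np.float32)
train_dataset = (tf.data.Dataset.from_tensor_slices((images, targets))
                 .shuffle(10).batch(4).prefetch(tf.data.AUTOTUNE))
valid_dataset = tf.data.Dataset.from_tensor_slices((images, targets)).batch(4)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4, 4, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

log_dir = tempfile.mkdtemp()
tensorboard_image_callback = TensorBoardImage(
    logdir=log_dir, train=train_dataset, validation=valid_dataset)
model.fit(train_dataset, epochs=2, validation_data=valid_dataset,
          callbacks=[tensorboard_image_callback], verbose=0)
tensorboard_image_callback.file_writer.flush()
```

Because the callback's iterator is separate from the one `fit` uses, a shuffled or randomly augmented pipeline produces an independent draw: the logged images show what the pipeline generates, not necessarily the exact batch the model just saw. For inspecting augmentation that is usually enough.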
Some related threads that haven't led me to an answer yet:
Accessing validation data within a custom callback
Create keras callback to save model predictions and targets for each batch during training
Comments on the question:
Any update on this?

Answer 1:
What ended up working for me was the following, using `tfds`. The `__init__` function:
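```python
import math

import tensorflow as tf
import tensorflow_datasets as tfds


class TensorBoardImage(tf.keras.callbacks.Callback):
    def __init__(self, logdir, train, validation=None):
        super(TensorBoardImage, self).__init__()
        self.logdir = logdir
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.train = train
        self.validation = validation
        # from tf.data: tfds.as_numpy() turns the dataset into an iterable of NumPy
        # batches. The iterable is not subscriptable, so pull one batch with
        # next(iter(...)); this assumes dict-style elements with an 'image' key,
        # as returned by tfds.load().
        my_data = next(iter(tfds.as_numpy(train)))
        self.sample_imgs = my_data['image']
```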
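The `tfds.as_numpy` step in `__init__` can also be done with core TensorFlow via `tf.data.Dataset.as_numpy_iterator()`, which likewise yields NumPy values batch by batch. The sketch below assumes a hypothetical dict-structured dataset shaped like the `{'image': ..., 'label': ...}` elements `tfds.load()` returns; it shows that the iterator is consumed with `next()`, while each yielded element can be indexed by key. For a small, unshuffled dataset (e.g. validation data), the images can be materialized once and sliced by batch number, approximating what `self.train[batch]` did for the generator.

```python
import numpy as np
import tensorflow as tf

# Hypothetical dict-structured dataset mimicking tfds.load() output.
features = {
    "image": np.zeros((6, 8, 8, 3), dtype=np.uint8),
    "label": np.arange(6, dtype=np.int64),
}
ds = tf.data.Dataset.from_tensor_slices(features).batch(4)

# The iterator is not indexable, but each yielded element (a dict) is.
first_batch = next(ds.as_numpy_iterator())
imgs = first_batch["image"]

# Small, unshuffled data only: load every image once, then slice by batch number.
all_imgs = np.concatenate([b["image"] for b in ds.as_numpy_iterator()])
batch_size = 4
batch_index = 1
batch_1_imgs = all_imgs[batch_index * batch_size:(batch_index + 1) * batch_size]
```

Materializing the images trades memory for index-based access, so it only makes sense when the data fits comfortably in RAM.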
Then `on_batch_end`:
```python
    def on_batch_end(self, batch, logs=None):
        # NOTE: indexing by batch and the attributes used below (samples, batch_size,
        # total_batches_seen, index_array, filenames) belong to a Keras
        # DirectoryIterator (e.g. from ImageDataGenerator.flow_from_directory);
        # a tf.data.Dataset has none of them.
        images_or_labels = 0  # 0=images, 1=labels
        imgs = self.train[batch][images_or_labels]

        # calculate epoch
        n_batches_per_epoch = self.train.samples / self.train.batch_size
        epoch = math.floor(self.train.total_batches_seen / n_batches_per_epoch)

        # since the training data is shuffled each epoch, we need to use the index_array
        # to find something which uniquely identifies the image and is constant
        # throughout training
        first_index_in_batch = batch * self.train.batch_size
        last_index_in_batch = first_index_in_batch + self.train.batch_size
        last_index_in_batch = min(last_index_in_batch, len(self.train.index_array))
        img_indices = self.train.index_array[first_index_in_batch:last_index_in_batch]

        with self.file_writer.as_default():
            for ix, img in enumerate(imgs):
                # only post 1 out of every 1000 images to tensorboard
                if (img_indices[ix] % 1000) == 0:
                    # instead of img_filename, I could just use str(img_indices[ix])
                    # as a unique identifier, but this way makes it easier to find
                    # the unaugmented image
                    img_filename = self.train.filenames[img_indices[ix]]
                    # convert float to uint8, shift range to 0-255
                    img = tf.cast(img, tf.float32)
                    img -= tf.reduce_min(img)
                    img *= 255 / tf.reduce_max(img)
                    img = tf.cast(img, tf.uint8)
                    img_tensor = tf.expand_dims(img, 0)  # tf.summary needs a 4D tensor
                    tf.summary.image(img_filename, img_tensor, step=epoch)
```
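The rescaling at the end of `on_batch_end` can be factored into a small helper. This sketch (the name `to_uint8` is hypothetical) performs the same min-max shift to 0-255, but uses `tf.math.divide_no_nan` so that a constant image, where `tf.reduce_max(img)` is 0 after the shift, becomes all zeros instead of NaN.

```python
import tensorflow as tf


def to_uint8(img):
    """Min-max rescale an image to [0, 255] and cast it to uint8 for tf.summary.image."""
    img = tf.cast(img, tf.float32)   # accept uint8/float64 inputs too
    img = img - tf.reduce_min(img)   # shift so the minimum is 0
    max_val = tf.reduce_max(img)
    # divide_no_nan returns 0 instead of NaN when the image is constant (max_val == 0)
    img = tf.math.divide_no_nan(img, max_val) * 255.0
    return tf.cast(tf.round(img), tf.uint8)


example = tf.constant([[0.0, 0.5], [1.0, 2.0]])
print(to_uint8(example).numpy().tolist())  # [[0, 64], [128, 255]]
```

Calling `tf.round` before the cast rounds to the nearest integer (ties go to the even value) instead of truncating, so 63.75 becomes 64 rather than 63.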
I didn't need to make any changes to the instantiation.
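I'd recommend using this only for debugging; otherwise it saves every n-th image in the dataset to TensorBoard on every epoch, which can end up taking a lot of disk space.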
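Selecting "every n-th image" by `img_indices[ix] % 1000` relies on the generator's `index_array`. In a `tf.data` pipeline, a comparable stable ID can be attached with `Dataset.enumerate()` before `shuffle()`, so each image keeps its original position even after shuffling. The sketch below uses a hypothetical `LOG_EVERY_N` of 5 on toy data (the answer uses 1000) and a hypothetical `images_to_log` helper; the IDs are dropped with `map()` for the dataset handed to `model.fit`. Raising `every_n` is the direct knob for the disk usage mentioned above.

```python
import numpy as np
import tensorflow as tf

N_SAMPLES = 20
LOG_EVERY_N = 5  # hypothetical; the answer above uses 1000

images = np.random.rand(N_SAMPLES, 4, 4, 3).astype(np.float32)
labels = np.zeros(N_SAMPLES, dtype=np.float32)

# enumerate() BEFORE shuffle() tags each element with its original position,
# giving a stable per-image ID (the role index_array plays for the generator).
tagged = (tf.data.Dataset.from_tensor_slices((images, labels))
          .enumerate()
          .shuffle(N_SAMPLES)
          .batch(8))

# The model still trains on (image, label) pairs, so drop the IDs for fit().
train_dataset = tagged.map(lambda ids, xy: xy)


def images_to_log(batch_ids, batch_imgs, every_n=LOG_EVERY_N):
    """Keep only the images whose stable ID is a multiple of every_n."""
    mask = tf.equal(batch_ids % every_n, 0)
    return tf.boolean_mask(batch_ids, mask), tf.boolean_mask(batch_imgs, mask)


# One pass over the data: which images would be written to TensorBoard?
logged_ids = []
for ids, (imgs, _) in tagged:
    keep_ids, keep_imgs = images_to_log(ids, imgs)
    logged_ids.extend(keep_ids.numpy().tolist())
```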