传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数
Posted
技术标签:
【中文标题】传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数【英文标题】:When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument 【发布时间】:2020-03-09 14:15:25 【问题描述】:我正在尝试使用这个 Google 的示例,但使用的是我自己的数据集:
https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/text_classification.ipynb
我创建了一个文件夹,类似于在他们的代码中下载的内容,其中包含训练和测试文件夹以及 txt 文件。
在我的例子中,data_path 如下:
data_path = '/Users/developer/.keras/datasets/chat'
每当我尝试运行它时model = text_classifier.create(train_data)
都会抛出一个错误
ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument.
这甚至意味着什么,我应该在哪里寻找问题?
import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')
from tensorflow_examples.lite.model_customization.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_customization.core.model_export_format import ModelExportFormat
import tensorflow_examples.lite.model_customization.core.task.text_classifier as text_classifier
# data_path = tf.keras.utils.get_file(
# fname='aclImdb',
# origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
# untar=True)
data_path = '/Users/developer/.keras/datasets/chat'
train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), class_labels=['greeting', 'goodbye'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), shuffle=False)
model = text_classifier.create(train_data)
loss, acc = model.evaluate(test_data)
model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')
【问题讨论】:
【参考方案1】:我遇到了类似的问题,然后在 model.fit 下我添加了steps_per_epoch
history = single_step_model.fit(train_data_single,
epochs=100,
callbacks=[lr_schedule],
steps_per_epoch=EVALUATION_INTERVAL)
当然,我在此之前输入了EVALUATION_INTERVAL
的值,因此它起作用了。希望对您有所帮助。
【讨论】:
【参考方案2】:问题是,当您为所需数量的 epoch 训练模型时,您的训练代码部分可能无法确定特定 epoch 何时开始以及该 epoch 何时结束。
因此,在训练期间,可以添加一个“steps_per_epoch”参数,以便它知道如何针对单个 epoch 的特定有限步数进行操作和训练。
在验证的情况下,我们会添加特定的“validation_steps”来解决相同的问题。
我通过向我的 tf.Keras model.fit() 代码添加 steps_per_epoch 和 validation_steps 参数解决了这个问题。
需要总结如何在代码中提供这些参数。
参考资料:
https://keras.io/api/models/model_training_apis/#fit-method https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/【讨论】:
以上是关于传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数的主要内容,如果未能解决你的问题,请参考以下文章