传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数

Posted

技术标签:

【中文标题】传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数【英文标题】:When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument 【发布时间】:2020-03-09 14:15:25 【问题描述】:

我正在尝试使用这个 Google 的示例,但使用的是我自己的数据集:

https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/text_classification.ipynb

我创建了一个文件夹,类似于在他们的代码中下载的内容,其中包含训练和测试文件夹以及 txt 文件。

在我的例子中,data_path 如下: data_path = '/Users/developer/.keras/datasets/chat'

每当我尝试运行它时model = text_classifier.create(train_data) 都会抛出一个错误 ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument. 这甚至意味着什么,我应该在哪里寻找问题?


import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_customization.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_customization.core.model_export_format import ModelExportFormat
import tensorflow_examples.lite.model_customization.core.task.text_classifier as text_classifier


# data_path = tf.keras.utils.get_file(
#       fname='aclImdb',
#       origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
#       untar=True)

data_path = '/Users/developer/.keras/datasets/chat'

train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), class_labels=['greeting', 'goodbye'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), shuffle=False)

model = text_classifier.create(train_data)
loss, acc = model.evaluate(test_data)
model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')

【问题讨论】:

【参考方案1】:

我遇到了类似的问题,然后在 model.fit 下我添加了steps_per_epoch

history = single_step_model.fit(train_data_single,
                                epochs=100, 
                                callbacks=[lr_schedule], 
                                steps_per_epoch=EVALUATION_INTERVAL)

当然,我在此之前输入了EVALUATION_INTERVAL 的值,因此它起作用了。希望对您有所帮助。

【讨论】:

【参考方案2】:

问题是,当您为所需数量的 epoch 训练模型时,您的训练代码部分可能无法确定特定 epoch 何时开始以及该 epoch 何时结束。

因此,在训练期间,可以添加一个“steps_per_epoch”参数,以便它知道如何针对单个 epoch 的特定有限步数进行操作和训练。

在验证的情况下,我们会添加特定的“validation_steps”来解决相同的问题。

我通过向我的 tf.Keras model.fit() 代码添加 steps_per_epoch 和 validation_steps 参数解决了这个问题。

需要总结如何在代码中提供这些参数。

参考资料:

https://keras.io/api/models/model_training_apis/#fit-method https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/

【讨论】:

以上是关于传递无限重复数据集时,您必须指定 `steps_per_epoch` 参数的主要内容,如果未能解决你的问题,请参考以下文章

当 sink 是二进制数据集时,源必须是二进制

无限网格滚动和局部过滤

处理 SSAS 多维数据集时发生异常 [重复]

如何在php中通过引用传递无限参数[重复]

crypt():没有指定盐参数。您必须使用随机生成的盐和强哈希函数来生成安全哈希 [重复]

SparkSQL:如何在从数据库加载数据集时指定分区列