在 TensorFlow 中获取数据集的长度

Posted

技术标签:

【中文标题】在 TensorFlow 中获取数据集的长度【英文标题】:Get length of a dataset in Tensorflow 【发布时间】:2018-05-23 23:17:57 【问题描述】:
source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code

我想完全洗牌我的整个数据集,但shuffle() 需要提取大量样本,而tf.Size() 不适用于tf.data.Dataset

如何正确洗牌?

【问题讨论】:

它应该是您较小的 csv 文件的大小。我不知道 Tensorflow 中有返回数据集长度的函数或属性。 来自documentation:结果数据集中元素的数量与最小数据集的大小相同 zip() 的工作方式相同;当 StopIteration 被提出时迭代结束(由最短的对象)。 【参考方案1】:

我正在使用 tf.data.FixedLengthRecordDataset() 并遇到了类似的问题。 就我而言,我试图只获取一定比例的原始数据。 由于我知道所有记录都有固定长度,因此我的解决方法是:

totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)

在您的情况下,我的建议是直接在 python 中计算“primary.csv”和“secondary.csv”中的记录数。或者,我认为出于您的目的,设置 buffer_size 参数并不需要计算文件。根据the accepted answer about the meaning of buffer_size,大于数据集中元素数量的数字将确保整个数据集中的统一随机播放。因此,只需输入一个非常大的数字(您认为会超过数据集的大小)就可以了。

【讨论】:

【参考方案2】:

从 TensorFlow 2 开始,可以通过 cardinality() 函数轻松检索数据集的长度。

dataset = tf.data.Dataset.range(42)
#both print 42 
dataset_length_v1 = tf.data.experimental.cardinality(dataset).numpy())
dataset_length_v2 = dataset.cardinality().numpy()

注意:当使用谓词时,例如过滤器,返回的长度可能是-2。可以参考here的解释,否则就看下面这段:

如果使用过滤谓词,基数可能返回值-2,因此未知;如果您确实在数据集上使用过滤谓词,请确保您已经以另一种方式计算了数据集的长度(例如,在对其应用 .from_tensor_slices() 之前,pandas 数据帧的长度。

【讨论】:

这对我尝试过的两个数据集都给出了-2。 是的,这就是为什么

以上是关于在 TensorFlow 中获取数据集的长度的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow的MNIST教程

数据集的获取

如何在 Tensorflow.js 中获取预测值

怎样用sql语句获取某个字段的长度

获取预测模型在测试集中预测错误的数据样本

NLP中文酒店评论语料文本数据分析