Tensorflow:如何查找 tf.data.Dataset API 对象的大小
Posted
技术标签:
【中文标题】Tensorflow:如何查找 tf.data.Dataset API 对象的大小【英文标题】:Tensorflow: How to find the size of a tf.data.Dataset API object 【发布时间】:2018-11-27 21:14:57 【问题描述】:我了解 Dataset API 是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小。我说的是存储在文本文件或 tfRecord 文件中的大型数据语料库。这些文件通常使用tf.data.TextLineDataset
或类似的东西读取。查找使用 tf.data.Dataset.from_tensor_slices
加载的数据集的大小很简单。
我询问数据集大小的原因如下: 假设我的数据集大小是 1000 个元素。批量大小 = 50 个元素。然后训练步骤/批次(假设 1 个 epoch)= 20。在这 20 个步骤中,我想以指数方式将我的学习率从 0.1 衰减到 0.01
tf.train.exponential_decay(
learning_rate = 0.1,
global_step = global_step,
decay_steps = 20,
decay_rate = 0.1,
staircase=False,
name=None
)
在上面的代码中,我有“和”想设置decay_steps = number of steps/batches per epoch = num_elements/batch_size
。只有事先知道数据集中元素的数量,才能计算出这一点。
提前知道大小的另一个原因是使用tf.data.Dataset.take()
、tf.data.Dataset.skip()
方法将数据分成训练集和测试集。
PS:我不是在寻找暴力方法,例如遍历整个数据集并更新计数器以计算元素数量或 putting a very large batch size and then finding the size of the resultant dataset 等。
【问题讨论】:
【参考方案1】:您可以手动指定数据集的大小吗?
我如何加载我的数据:
sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")
sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
num_samples = tf.size(sample_ids)
data = tf.data.Dataset.from_tensor_slices(sample_ids)
# "load" data by id:
# return (id, data) for each id
data = data.map(
lambda id: (id, some_load_op(id))
)
在这里,您可以通过使用占位符初始化一次 sample_ids
来指定所有示例 ID。
您的样本 ID 可能是例如文件路径或简单数字 (np.arange(num_elems)
)
然后可以在num_samples
中获得元素的数量。
【讨论】:
感谢您的回答。但是,您似乎没有得到我的问题,对此感到抱歉。我将再次修改我的问题。我没有使用from_tensor_slices
,它在您有小数据集时使用。在这种情况下,找到数据集的大小是微不足道的。我的问题是关于读取存储在文本文件中的大量数据,这些数据只能使用tf.data.TextLineDataset()
读取。在这种情况下,如何求出整个数据集的大小?【参考方案2】:
您可以使用以下方法轻松获取数据样本数:
dataset.__len__()
你可以像这样获取每个元素:
for step, element in enumerate(dataset.as_numpy_iterator()):
... print(step, element)
你也可以得到一个样本的形状:
dataset.element_spec
如果你想获取特定的元素,你也可以使用 shard 方法。
【讨论】:
dataset.__len__()
不适用于使用tf.data.TextLineDataset
创建的数据集。错误是TypeError: dataset length is unknown.
我发现 TextLineDataset 读取每个文件并将该文件的每一行作为样本,所以你可以做的是 len(dataset.as_numpy_iterator())。但是,每行的字符数会有所不同,然后您可以使用一些固定的大小使其可训练。【参考方案3】:
我知道这个问题已经有两年了,但也许这个答案会有用。
如果您使用tf.data.TextLineDataset
读取数据,那么获取样本数量的一种方法可能是计算您正在使用的所有文本文件中的行数。
考虑以下示例:
import random
import string
import tensorflow as tf
filenames = ["data0.txt", "data1.txt", "data2.txt"]
# Generate synthetic data.
for filename in filenames:
with open(filename, "w") as f:
lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
print("\n".join(lines), file=f)
dataset = tf.data.TextLineDataset(filenames)
尝试使用len
获取长度会引发TypeError
:
len(dataset)
但是可以相对快速地计算出文件中的行数。
# https://***.com/q/845058/5666087
def get_n_lines(filepath):
i = -1
with open(filepath) as f:
for i, _ in enumerate(f):
pass
return i + 1
n_lines = sum(get_n_lines(f) for f in filenames)
在上面,n_lines
等于迭代数据集时找到的元素数
for i, _ in enumerate(dataset):
pass
n_lines == i + 1
【讨论】:
【参考方案4】:这是我为解决问题所做的,将以下行添加到您的数据集tf.data.experimental.assert_cardinality(len_of_data)
这将解决问题,
ast = Audioset(df) # the generator class
db = tf.data.Dataset.from_generator(ast, output_types=(tf.float32, tf.float32, tf.int32))
db = db.apply(tf.data.experimental.assert_cardinality(len(ast))) # number of samples
db = db.batch(batch_size)
数据集 len 根据 batch_size 改变,要获得数据集 len 只需运行 len(db)
。
查看here了解更多详情
【讨论】:
以上是关于Tensorflow:如何查找 tf.data.Dataset API 对象的大小的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Tensorflow 对象检测 API 中查找边界框坐标