Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法

Posted 琥珀彩

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法相关的知识,希望对你有一定的参考价值。

1.Dataset数据对象

Dataset可以用来表示输入管道元素集合(张量的嵌套结构)和“逻辑计划“对这些元素的转换操作。在Dataset中元素可以是向量,元组或字典等形式。
另外,Dataset需要配合另外一个类Iterator进行使用,Iterator对象是一个迭代器,可以对Dataset中的元素进行迭代提取。

2.Dataset方法

2.1 产生数据集
2.1.1. from_tensor_slices

from_tensor_slices 用于创建dataset,其元素是给定张量的切片的元素。

函数形式:from_tensor_slices(tensors)

参数tensors:张量的嵌套结构,每个都在第0维中具有相同的大小。

import tensorflow as tf
#创建一个Dataset对象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9,10,11,12])
'''合成批次'''
dataset=dataset.batch(5)
#创建一个迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函数可以帮助我们从迭代器中获取元素
element = iterator.get_next()
 
#遍历迭代器,获取所有元素
with tf.Session() as sess:   
    for i in range(9):
       print(sess.run(element))  

输出

[1 2 3 4 5]
[ 6  7  8  9 10]
[11 12]

2.1.2 .from_tensors

创建一个Dataset包含给定张量的单个元素。

函数形式:from_tensors(tensors)

参数tensors:张量的嵌套结构。

dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])
 
iterator = concat_dataset.make_one_shot_iterator()
 
element = iterator.get_next()
 
with tf.Session() as sess:   
    for i in range(1):
       print(sess.run(element))

区别:

  • from_tensors是将tensors作为一个整体进行操纵,而from_tensor_slices可以操纵tensors里面的元素。

2.1.3 from_generator(具体实践不太了解)

创建Dataset由其生成元素的元素generator。

函数形式:from_generator(generator,output_types,output_shapes=None,args=None)

参数generator:一个可调用对象,它返回支持该iter()协议的对象 。如果args未指定,generator则不得参数; 否则它必须采取与有值一样多的参数args。
参数output_types:tf.DType对应于由元素生成的元素的每个组件的对象的嵌套结构generator。
参数output_shapes:tf.TensorShape 对应于由元素生成的元素的每个组件的对象 的嵌套结构generator
参数args:tf.Tensor将被计算并将generator作为NumPy数组参数传递的对象元组。

 

2.2 数据转换Transformation
2.2.1 batch

# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。

或者将for循环改为:while True:。就不用设置循环次数了。

 

2.2.2 shuffle


上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。

注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱(实际并未shuffle)。

随机混洗数据集的元素。

函数形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

参数buffer_size:表示新数据集将从中采样的数据集中的元素数。

buffer_size=1:不打乱顺序,既保持原序
buffer_size越大,打乱程度越大
参数seed:(可选)表示将用于创建分布的随机种子。
参数reshuffle_each_iteration:(可选)一个布尔值,如果为true,则表示每次迭代时都应对数据集进行伪随机重组。(默认为True。)

在这里buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从dataset中提取数据将它填满。batch操作,从buffer中提取。

如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size= Dataset_size。

 

2.2.3 map


map可以将map_func函数映射到数据集.
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,

函数形式:flat_map(map_func,num_parallel_calls=None)

参数map_func:映射函数
参数num_parallel_calls:表示要并行处理的数字元素。如果未指定,将按顺序处理元素。如果使用值tf.data.experimental.AUTOTUNE,则根据可用的CPU动态设置并行调用的数量。

对dataset中每个元素的值加10

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

[16 17 18 19]
[10 11 12 13 14 15]


2.2.4 repeat

重复此数据集次数,主要用来处理机器学习中的epoch,假设原先的数据训练一个epoch,使用repeat(2)就可以将之变成2个epoch,默认空是无限次。

函数形式:repeat(count=None)

参数count:(可选)表示数据集应重复的次数。默认行为(如果count是None或-1)是无限期重复的数据集。

 


————————————————
版权声明:本文为CSDN博主「rrr2」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35608277/article/details/116333888

以上是关于Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 对象检测 API 的数据增强是不是会产生比原始样本更多的样本?

在自己的数据集上训练 TensorFlow 对象检测

Tensorflow 对象检测 API - 验证丢失行为

Tensorflow 对象检测 API 中的过拟合

具有奇怪检测结果的 TensorFlow 对象检测 api

如何在 tensorflow 对象检测 API 中使用“忽略”类?