Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法
Posted 琥珀彩
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法相关的知识,希望对你有一定的参考价值。
1.Dataset数据对象
Dataset可以用来表示输入管道元素集合(张量的嵌套结构)和“逻辑计划“对这些元素的转换操作。在Dataset中元素可以是向量,元组或字典等形式。
另外,Dataset需要配合另外一个类Iterator进行使用,Iterator对象是一个迭代器,可以对Dataset中的元素进行迭代提取。
2.Dataset方法
2.1 产生数据集
2.1.1. from_tensor_slices
from_tensor_slices 用于创建dataset,其元素是给定张量的切片的元素。
函数形式:from_tensor_slices(tensors)
参数tensors:张量的嵌套结构,每个都在第0维中具有相同的大小。
import tensorflow as tf
#创建一个Dataset对象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9,10,11,12])
'''合成批次'''
dataset=dataset.batch(5)
#创建一个迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函数可以帮助我们从迭代器中获取元素
element = iterator.get_next()
#遍历迭代器,获取所有元素
with tf.Session() as sess:
for i in range(9):
print(sess.run(element))
输出
[1 2 3 4 5]
[ 6 7 8 9 10]
[11 12]
2.1.2 .from_tensors
创建一个Dataset包含给定张量的单个元素。
函数形式:from_tensors(tensors)
参数tensors:张量的嵌套结构。
dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])
iterator = concat_dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
for i in range(1):
print(sess.run(element))
区别:
- from_tensors是将tensors作为一个整体进行操纵,而from_tensor_slices可以操纵tensors里面的元素。
2.1.3 from_generator(具体实践不太了解)
创建Dataset由其生成元素的元素generator。
函数形式:from_generator(generator,output_types,output_shapes=None,args=None)
参数generator:一个可调用对象,它返回支持该iter()协议的对象 。如果args未指定,generator则不得参数; 否则它必须采取与有值一样多的参数args。
参数output_types:tf.DType对应于由元素生成的元素的每个组件的对象的嵌套结构generator。
参数output_shapes:tf.TensorShape 对应于由元素生成的元素的每个组件的对象 的嵌套结构generator
参数args:tf.Tensor将被计算并将generator作为NumPy数组参数传递的对象元组。
2.2 数据转换Transformation
2.2.1 batch
# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。
或者将for循环改为:while True:。就不用设置循环次数了。
2.2.2 shuffle
上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。
注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱(实际并未shuffle)。
随机混洗数据集的元素。
函数形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)
参数buffer_size:表示新数据集将从中采样的数据集中的元素数。
buffer_size=1:不打乱顺序,既保持原序
buffer_size越大,打乱程度越大
参数seed:(可选)表示将用于创建分布的随机种子。
参数reshuffle_each_iteration:(可选)一个布尔值,如果为true,则表示每次迭代时都应对数据集进行伪随机重组。(默认为True。)
在这里buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从dataset中提取数据将它填满。batch操作,从buffer中提取。
如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size= Dataset_size。
2.2.3 map
map可以将map_func函数映射到数据集.
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,
函数形式:flat_map(map_func,num_parallel_calls=None)
参数map_func:映射函数
参数num_parallel_calls:表示要并行处理的数字元素。如果未指定,将按顺序处理元素。如果使用值tf.data.experimental.AUTOTUNE,则根据可用的CPU动态设置并行调用的数量。
对dataset中每个元素的值加10
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
[16 17 18 19]
[10 11 12 13 14 15]
2.2.4 repeat
重复此数据集次数,主要用来处理机器学习中的epoch,假设原先的数据训练一个epoch,使用repeat(2)就可以将之变成2个epoch,默认空是无限次。
函数形式:repeat(count=None)
参数count:(可选)表示数据集应重复的次数。默认行为(如果count是None或-1)是无限期重复的数据集。
————————————————
版权声明:本文为CSDN博主「rrr2」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35608277/article/details/116333888
以上是关于Tensorflow 数据对象Dataset.shuffle()repeat()batch() 等用法的主要内容,如果未能解决你的问题,请参考以下文章