使用tensorflow中的Dataset来读取制作好的tfrecords文件
Posted daremosiranaihana
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用tensorflow中的Dataset来读取制作好的tfrecords文件相关的知识,希望对你有一定的参考价值。
上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中可能就无法使用了,因此不建议大家使用这个函数,好了,下面就来看看是如何进行读取的吧。
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 #定义可以一次获得多张图像的函数 5 def show_image(image_dir): 6 plt.imshow(image_dir) 7 plt.axis(‘on‘) 8 plt.show() 9 10 #单个record的解析函数 11 def decode_example(example):#,resize_height,resize_width,labels_nums): 12 features=tf.io.parse_single_example(example,features= 13 ‘image_raw‘:tf.io.FixedLenFeature([],tf.string), 14 ‘label‘:tf.io.FixedLenFeature([],tf.int64) 15 ) 16 tf_image=tf.decode_raw(features[‘image_raw‘],tf.uint8)#这个其实就是图像的像素模式,之前我们使用矩阵来表示图像 17 tf_image=tf.reshape(tf_image,shape=[224,224,3])#对图像的尺寸进行调整,调整成三通道图像 18 tf_image=tf.cast(tf_image,tf.float32)*(1./255)#对图像进行归一化以便保持和原图像有相同的精度 19 tf_label=tf.cast(features[‘label‘],tf.int32) 20 tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#将label转化成用one_hot编码的格式 21 return tf_image,tf_label 22 23 def batch_test(tfrecords_file): 24 dataset=tf.data.TFRecordDataset(tfrecords_file) 25 dataset=dataset.map(decode_example) 26 dataset=dataset.shuffle(100).batch(4) 27 iterator=tf.compat.v1.data.make_one_shot_iterator(dataset) 28 batch_images,batch_labels=iterator.get_next() 29 30 init_op=tf.compat.v1.global_variables_initializer() 31 with tf.compat.v1.Session() as sess: 32 sess.run(init_op) 33 coord=tf.train.Coordinator() 34 threads=tf.train.start_queue_runners(coord=coord) 35 for i in range(4): 36 images,labels=sess.run([batch_images,batch_labels]) 37 show_image(images[1,:,:,:]) 38 print(‘shape:,tpye:,labels:‘.format(images.shape, images.dtype, labels)) 39 40 coord.request_stop() 41 coord.join(threads) 42 43 if __name__==‘__main__‘: 44 tfrecords_file=‘D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords‘ 45 resize_height=224 46 resize_width=224 47 batch_test(tfrecords_file)
我为了测试,写了batch_test这个函数,因为我想试一试看我做的tfrecords能不能被解析成功,如果你不想测试只想训练,那你直接把images_batch,和labels_batch放到网络中进行训练就可以了,还有一点要注意的,tf.global_variables_initializer()已经被tf.compat.v1.global_variables_initializer()所取代了,我做的时候不知道所以报了一个warning提示,同时tf.Sesssion()已经被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已经被tf.compat.v1.data.make_one_shot_iterator(dataset) 所代替,这些异常要注意,然后我只是将每个batch的第二张图片显示出来了,你也可以显示其他的,但是意义不大,反正只是测试一下解析成功与否,成功了我们就不需要纠结别的了。好啦,就是这样,接下来我会把这些东西放到网络中进行训练,再更新我的学习,就酱。
以上是关于使用tensorflow中的Dataset来读取制作好的tfrecords文件的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow - tf.data.Dataset 读取大型 HDF5 文件
深度学习(tensorflow) —— 自己数据集读取opencv
TensorFlow学习(十五):使用tf.data来创建输入流(上)
TensorFlow学习(十五):使用tf.data来创建输入流(上)