Tensorflow : 读取数据的三种方式及tfrecord的使用

Posted 明天去哪

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow : 读取数据的三种方式及tfrecord的使用相关的知识,希望对你有一定的参考价值。

参考: https://blog.csdn.net/lujiandong1/article/details/53376802
https://blog.csdn.net/happyhorizion/article/details/77894055

读取数据的三种方式

Preloaded data: 预加载数据

import tensorflow as tf  
# 设计Graph  
x1 = tf.constant([2, 3, 4])  
x2 = tf.constant([4, 0, 1])  
y = tf.add(x1, x2)  
# 打开一个session --> 计算y  
with tf.Session() as sess:  
    print sess.run(y) 

Feeding: Python产生数据,再把数据喂给后端

import tensorflow as tf  
# 设计Graph  
x1 = tf.placeholder(tf.int16)  
x2 = tf.placeholder(tf.int16)  
y = tf.add(x1, x2)  
# 用Python产生数据  
li1 = [2, 3, 4]  
li2 = [4, 0, 1]  
# 打开一个session --> 喂数据 --> 计算y  
with tf.Session() as sess:  
    print sess.run(y, feed_dict=x1: li1, x2: li2)  

Reading from file: 从文件中直接读取

#-*- coding:utf-8 -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义Reader
reader = tf.TextLineReader()
# key, value会分别得到文件名及行数和文件内容:['A.csv:3', 'Alpha3,A3']
key, value = reader.read(filename_queue)
# 定义Decoder
# sess.run([example, label])
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)
# 运行Graph
with tf.Session() as sess:
    coord = tf.train.Coordinator()  #创建一个协调器,管理线程
    threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。
    for i in range(10):
        train_1, label_1 = sess.run([example_batch, label_batch])
        print train_1, label_1
    coord.request_stop()
    coord.join(threads)

也可以使用shuffle_batch来实现,其中shuffle_batch不使用时数据是按照顺序的。
batch和batch_join的区别: 一般来说,单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch);而对于多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有对应的tf.train.shuffle_batch_join使用),与多个reader对应

#-*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
#定义了多种解码器,每个解码器跟一个reader相连
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                  for _ in range(2)]  # Reader设置为2
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
example_batch, label_batch = tf.train.shuffle_batch_join(
      example_list, batch_size=5, capacity=200,
                       min_after_dequeue=10)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        e_val,l_val = sess.run([example_batch,label_batch])
        print e_val,l_val
    coord.request_stop()
    coord.join(threads)

使用多个reader可以并行读取数据,提高效率
使用join可以保证实例与标签对应
也可以在代码中设置epochs

tfrecords

除了使用csv或者其他格式的数据,推荐使用tf内定标准格式——tfrecords

制作tfrecords

import tensorflow as tf
import numpy as np

# create tfrecord_writer
tfrecords_filename = 'test_tf.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for i in range(100):
    img_raw = np.random.random_integers(0, 255, size=(7, 30))
    img_raw = img_raw.tostring()
    example = tf.train.Example(features = tf.train.Features(
        feature = 
            'label' : tf.train.Feature(int64_list = tf.train.Int64List(value = [i])),
            'img_raw' : tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
        
    ))
    writer.write(example.SerializeToString())

writer.close()

由于存在于内存中的对象都是暂时的,无法长期驻存,为了把对象的状态保持下来,这时需要把对象写入到磁盘或者其他介质中,这个过程就叫做序列化。不序列化则无法保存

解析tfrecors

#encoding: utf-8
import tensorflow as tf
import numpy as np
from PIL import Image


if __name__=='__main__':
    tfrecords_filename = "test_tf.tfrecords"
    filename_queue = tf.train.string_input_producer([tfrecords_filename],) #读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features=
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       )  #取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'],tf.int64)
    image = tf.reshape(image, [7,30])
    label = tf.cast(features['label'], tf.int64)
    image, label = tf.train.shuffle_batch([image, label],
                                               batch_size= 2, capacity=200, min_after_dequeue=10)

    with tf.Session() as sess: #开始一个会话
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        coord=tf.train.Coordinator()
        threads= tf.train.start_queue_runners(coord=coord)
        for i in range(20):
            example, l = sess.run([image,label])#在会话中取出image和label
            img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
            img.save('./'+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
            print(example, l)

        coord.request_stop()
        coord.join(threads)

Note: 这里使用shuffle_batch_size就出问题了,暂不清楚为什么。这样的话,以上代码不知道是否数据与标签统一,看晚上的做法都是这样做的
对于以上问题,经过实验发现,并不会导致数据与标签不一致

以上是关于Tensorflow : 读取数据的三种方式及tfrecord的使用的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow 2.X中构建模型的三种方式:Sequential, Functional, Subclassing

Tensorflow机器学习入门——读取数据

tensorflow数据加载方式

tensorflow API _ 3 (tf.train.polynomial_decay)

Tensorflow数据读取方式总结

Tensorflow数据读取方式总结