TFRecord 的使用

Posted 血影雪梦

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TFRecord 的使用相关的知识,希望对你有一定的参考价值。

什么是 TFRecord 

            PS:这段内容摘自 http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

            一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocolbuffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriterclass写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。
            从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocolbuffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 

 

代码

            adjust_pic.py

                单纯的转换图片大小

 

# -*- coding: utf-8 -*-

import tensorflow as tf

def resize(img_data, width, high, method=0):
    return tf.image.resize_images(img_data,[width, high], method)


                pic2tfrecords.py

                将图片保存成TFRecord

 

# -*- coding: utf-8 -*-
# 将图片保存成 TFRecord
import os.path
import matplotlib.image as mpimg
import tensorflow as tf
import adjust_pic as ap
from PIL import Image


SAVE_PATH = 'data/dataset.tfrecords'


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def load_data(datafile, width, high, method=0, save=False):
    train_list = open(datafile,'r')
    # 准备一个 writer 用来写 TFRecord 文件
    writer = tf.python_io.TFRecordWriter(SAVE_PATH)

    with tf.Session() as sess:
        for line in train_list:
            # 获得图片的路径和类型
            tmp = line.strip().split(' ')
            img_path = tmp[0]
            label = int(tmp[1])

            # 读取图片
            image = tf.gfile.FastGFile(img_path, 'r').read()
            # 解码图片(如果是 png 格式就使用 decode_png)
            image = tf.image.decode_jpeg(image)
            # 转换数据类型
            # 因为为了将图片数据能够保存到 TFRecord 结构体中,所以需要将其图片矩阵转换成 string,所以为了在使用时能够转换回来,这里确定下数据格式为 tf.float32
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            # 既然都将图片保存成 TFRecord 了,那就先把图片转换成希望的大小吧
            image = ap.resize(image, width, high)
            # 执行 op: image
            image = sess.run(image)
            
            # 将其图片矩阵转换成 string
            image_raw = image.tostring()
            # 将数据整理成 TFRecord 需要的数据结构
            example = tf.train.Example(features=tf.train.Features(feature=
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(label),
                ))

            # 写 TFRecord
            writer.write(example.SerializeToString())

    writer.close()


load_data('train_list.txt_bak', 224, 224)

                tfrecords2data.py

                从TFRecord中读取并保存成图片

# -*- coding: utf-8 -*-
# 从 TFRecord 中读取并保存图片
import tensorflow as tf
import numpy as np


SAVE_PATH = 'data/dataset.tfrecords'


def load_data(width, high):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([SAVE_PATH])

    # 从 TFRecord 读取内容并保存到 serialized_example 中
    _, serialized_example = reader.read(filename_queue)
    # 读取 serialized_example 的格式
    features = tf.parse_single_example(
        serialized_example,
        features=
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        )

    # 解析从 serialized_example 读取到的内容
    images = tf.decode_raw(features['image_raw'], tf.uint8)
    labels = tf.cast(features['label'], tf.int64)

    with tf.Session() as sess:
        # 启动多线程
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 因为我这里只有 2 张图片,所以下面循环 2 次
        for i in range(2):
            # 获取一张图片和其对应的类型
            label, image = sess.run([labels, images])
            # 这里特别说明下:
            #   因为要想把图片保存成 TFRecord,那就必须先将图片矩阵转换成 string,即:
            #       pic2tfrecords.py 中 image_raw = image.tostring() 这行
            #   所以这里需要执行下面这行将 string 转换回来,否则会无法 reshape 成图片矩阵,请看下面的小例子:
            #       a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩阵
            #       b = a.tostring()
            #       # 下面这行的输出是 32,即: 2*2 之后还要再乘 8
            #       # 如果 tostring 之后的长度是 2*2=4 的话,那可以将 b 直接 reshape([2, 2]),但现在的长度是 2*2*8 = 32,所以无法直接 reshape
            #       # 同理如果你的图片是 500*500*3 的话,那 tostring() 之后的长度是 500*500*3 后再乘上一个数
            #       print len(b)
            #
            #   但在网上有很多提供的代码里都没有下面这一行,你们那真的能 reshape ?
            image = np.fromstring(image, dtype=np.float32)
            # reshape 成图片矩阵
            image = tf.reshape(image, [224, 224, 3])
            # 因为要保存图片,所以将其转换成 uint8
            image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
            # 按照 jpeg 格式编码
            image = tf.image.encode_jpeg(image)
            # 保存图片
            with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:
                f.write(sess.run(image))


load_data(224, 224)

train_list.txt_bak 中的内容如下:

image_1093.jpg 13
image_0805.jpg 10



以上是关于TFRecord 的使用的主要内容,如果未能解决你的问题,请参考以下文章

TFRecord 的使用

TFRecord文件

Spark-TFRecord:Apache Spark与TensorFlow TFRecord互操作示例

使用 tfslim 解码 tfrecord

为对象检测任务创建 tfrecord

TFRecord读写简介+Demo