用tensorflow创建tfrecords格式的数据集
Posted lyf98
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用tensorflow创建tfrecords格式的数据集相关的知识,希望对你有一定的参考价值。
下面的代码是生成一个每个图片大小是227*227*1的tfrecord文件,label是这个类别的英文名。
原图片是256*256*3RGB型的.jpg文件,在制作数据集的时候由于对图片的颜色没有要求,所以为了节省空间,进行了灰度化处理。
import tensorflow as tf import os import sys from PIL import Image import numpy as np # 数据集路径 TRAIN_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/train/" TEST_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/test/" # tfrecord文件存放路径 TFRECORD_DIR = "E:/python文件/tensorflow_learn/MyNet/images/" # 类型名 classes = {"apple_scab", "black_rot", "cedar_apple_rust", "healthy"} # 判断tfrecord文件是否存在 def _dataset_exists(tfrecord_dir): for split_name in [‘train‘, ‘test‘]: # 产生test.tfrecords和 train.tfrecords文件路径 output_filename = os.path.join(tfrecord_dir, split_name+‘.tfrecords‘) if not tf.gfile.Exists(output_filename): return False return True def int64_feature(values): if not isinstance(values, (tuple, list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # 获取该类别的所有文件 def _get_filenames_and_classes(dataset_dir): photo_filename = [] for filename in os.listdir(dataset_dir): # 获取文件路径 path = os.path.join(dataset_dir, filename) photo_filename.append(path) return photo_filename # 把数据转换为TFRecord格式 def _convert_dataset(split_name, dataset_dir): assert split_name in [‘train‘, ‘test‘] with tf.Session() as sess: output_filename = os.path.join(TFRECORD_DIR, split_name+‘.tfrecords‘) with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: for index, name in enumerate(classes): if split_name == ‘train‘: class_path = TRAIN_DATASET_DIR + name + ‘/‘ else: class_path = TEST_DATASET_DIR + name + ‘/‘ filenames = _get_filenames_and_classes(class_path) for i, img_name in enumerate(filenames): sys.stdout.write(‘ >>%s %s Convering image: %d/%d‘ % (split_name, name, i+1, len(filenames))) print(str(img_name)) sys.stdout.flush() image_data = Image.open(img_name) image_data = image_data.resize((227, 227)) image_data = np.array(image_data.convert(‘L‘)) # 图片灰度化处理 img_raw = image_data.tobytes() example = tf.train.Example( features=tf.train.Features( feature={ ‘img_raw‘: tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), ‘label‘: tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), } ) ) tfrecord_writer.write(example.SerializeToString()) # tfrecord_writer.close() # 判断tfrecord文件是否存在 if _dataset_exists(TFRECORD_DIR): print("文件已存在") else: # 数据转换 _convert_dataset(‘test‘, TEST_DATASET_DIR) _convert_dataset(‘train‘, TRAIN_DATASET_DIR)
print(‘生成tfrecord文件!‘)
以上是关于用tensorflow创建tfrecords格式的数据集的主要内容,如果未能解决你的问题,请参考以下文章
目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
Tensorflow学习教程------tfrecords数据格式生成与读取