二月五号博客
Posted goubb
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了二月五号博客相关的知识，希望对你有一定的参考价值。
今天学习了 TensorFlow 的文件读取操作。
一、读取图片文件
def read_picture(image_dir="./dog"):
    """Read one batch of dog images through a TF1 queue-based input pipeline.

    Every entry in ``image_dir`` is put into a filename queue, read whole,
    decoded as a JPEG, resized to 200x200 and grouped into batches of 100
    with ``tf.train.batch``. One step of the pipeline is run in a session
    and the symbolic tensors plus the fetched values are printed.

    Args:
        image_dir: Directory whose entries are all read as JPEG images.
            Defaults to ``"./dog"``, the location used by the original code.

    Returns:
        None. Results are only printed.
    """
    # 1. Build the filename queue: list the directory and prefix each name
    #    with the directory path.
    filename_list = os.listdir(image_dir)
    file_list = [os.path.join(image_dir, name) for name in filename_list]
    file_queue = tf.train.string_input_producer(file_list)

    # 2. Read and decode.
    # WholeFileReader yields (filename, file contents) pairs.
    reader = tf.WholeFileReader()
    key, value = reader.read(file_queue)
    print("key: ", key)
    print("value: ", value)
    # Force 3 output channels so grayscale/CMYK JPEGs match the static shape
    # pinned below; otherwise such an image only fails later, at batch time.
    image_decoded = tf.image.decode_jpeg(value, channels=3)
    print("image_decoded: ", image_decoded)
    # Resize every image to the same spatial size.
    image_resized = tf.image.resize_images(image_decoded, [200, 200])
    print("image_resized_before: ", image_resized)
    # Pin the static shape so tf.train.batch can build its queue.
    image_resized.set_shape([200, 200, 3])
    print("image_resized_after: ", image_resized)

    # 3. Batch queue.
    image_batch = tf.train.batch([image_resized], batch_size=100, num_threads=2, capacity=100)
    print("image_batch: ", image_batch)

    with tf.Session() as sess:
        # The coordinator is used to stop and join the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            filename, sample, image, n_image = sess.run([key, value, image_resized, image_batch])
            print("filename: ", filename)
            print("sample: ", sample)
            print("image: ", image)
            print("n_image: ", n_image)
        finally:
            # Always shut the threads down, even if sess.run raised, so they
            # are not left running against a session that is being closed.
            coord.request_stop()
            coord.join(threads)

    return None
二、读取二进制文件
class Cifar:
    """Read CIFAR-10 samples with TF1 queue-based input pipelines."""

    def __init__(self):
        # Geometry of one image.
        self.height = 32
        self.width = 32
        self.channel = 3
        # Byte counts of one fixed-length record: image bytes, label bytes,
        # and the total record length.
        self.image = self.height * self.width * self.channel
        self.label = 1
        self.sample = self.image + self.label

    def read_binary(self, data_dir="./cifar-10-batches-bin"):
        """Read one batch of images and labels from the CIFAR-10 binary files.

        Each record is ``self.sample`` bytes: ``self.label`` label byte(s)
        followed by ``self.image`` pixel bytes stored channel-major
        (channel, height, width). Pixels are transposed to
        (height, width, channel) before batching. Intermediate tensors and
        the fetched values are printed.

        Args:
            data_dir: Directory holding the ``*.bin`` batch files. Defaults to
                ``"./cifar-10-batches-bin"``, the original hard-coded location.

        Returns:
            Tuple ``(image_value, label_value)`` of fetched uint8 arrays with
            shapes (100, height, width, channel) and (100, self.label).
        """
        # 1. Build the filename queue from the .bin files only.
        filename_list = os.listdir(data_dir)
        file_list = [os.path.join(data_dir, name) for name in filename_list if name.endswith(".bin")]
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode.
        # Each read returns (key = file/record id, value = one raw record).
        reader = tf.FixedLengthRecordReader(self.sample)
        key, value = reader.read(file_queue)
        # Reinterpret the raw record as a flat uint8 vector.
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded: ", image_decoded)
        # Split the record: label bytes first, then pixel bytes.
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label: ", label)
        print("image: ", image)
        # Pixels are stored channel-major, so reshape to (C, H, W) ...
        image_reshaped = tf.reshape(image, [self.channel, self.height, self.width])
        print("image_reshaped: ", image_reshaped)
        # ... and transpose to (H, W, C).
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed: ", image_transposed)

        # 3. Batch queue.
        image_batch, label_batch = tf.train.batch([image_transposed, label], batch_size=100, num_threads=2, capacity=100)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                label_value, image_value = sess.run([label_batch, image_batch])
                print("label_value: ", label_value)
                print("image: ", image_value)
            finally:
                # Always stop and join the queue-runner threads, even on error.
                coord.request_stop()
                coord.join(threads)

        return image_value, label_value
三、读取 TFRecords 文件
def read_tfrecords(self, filename="cifar10.tfrecords"):
    """Read one batch of (image, label) examples from a TFRecords file.

    Each serialized example is parsed with an ``"image"`` bytes feature and a
    ``"label"`` int64 feature. The image bytes are decoded as uint8 and
    reshaped to (self.height, self.width, self.channel), then examples are
    grouped into batches of 100. Intermediate tensors and fetched values are
    printed.

    NOTE(review): ``self`` is presumably a ``Cifar`` instance (it must provide
    ``height``, ``width`` and ``channel``); this function looks like a method
    split out of that class. The reshape also assumes the writer serialized
    pixels in (height, width, channel) order — confirm against the writer.

    Args:
        filename: Path of the TFRecords file. Defaults to
            ``"cifar10.tfrecords"``, the original hard-coded path.

    Returns:
        None. Results are only printed.
    """
    # 1. Build the filename queue.
    file_queue = tf.train.string_input_producer([filename])

    # 2. Read and decode.
    reader = tf.TFRecordReader()
    key, value = reader.read(file_queue)
    # Parse one serialized example into its features.
    feature = tf.parse_single_example(value, features={
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64)
    })
    image = feature["image"]
    label = feature["label"]
    print("read_tf_image: ", image)
    print("read_tf_label: ", label)
    # Reinterpret the image bytes as a flat uint8 vector.
    image_decoded = tf.decode_raw(image, tf.uint8)
    print("image_decoded: ", image_decoded)
    # Restore the image shape.
    image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
    print("image_reshaped: ", image_reshaped)

    # 3. Batch queue.
    image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
    print("image_batch: ", image_batch)
    print("label_batch: ", label_batch)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value: ", image_value)
            print("label_value: ", label_value)
        finally:
            # Release resources: stop and join the queue-runner threads even
            # if sess.run raised.
            coord.request_stop()
            coord.join(threads)

    return None
以上是关于二月五号博客的主要内容，如果未能解决你的问题，请参考以下文章。