吴裕雄 python 神经网络——TensorFlow TFRecord样例程序
Posted tszr
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了吴裕雄 python 神经网络——TensorFlow TFRecord样例程序相关的知识,希望对你有一定的参考价值。
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 将数据转化为tf.train.Example格式。 def _make_example(pixels, label, image): image_raw = image.tostring() example = tf.train.Example(features=tf.train.Features(feature={ ‘pixels‘: _int64_feature(pixels), ‘label‘: _int64_feature(np.argmax(label)), ‘image_raw‘: _bytes_feature(image_raw) })) return example # 读取mnist训练数据。 mnist = input_data.read_data_sets("E:\\\\MNIST_data\\\\",dtype=tf.uint8, one_hot=True) images = mnist.train.images labels = mnist.train.labels pixels = images.shape[1] num_examples = mnist.train.num_examples # 输出包含训练数据的TFRecord文件。 with tf.python_io.TFRecordWriter("E:\\\\MNIST_data\\\\output.tfrecords") as writer: for index in range(num_examples): example = _make_example(pixels, labels[index], images[index]) writer.write(example.SerializeToString()) print("TFRecord训练文件已保存。") # 读取mnist测试数据。 images_test = mnist.test.images labels_test = mnist.test.labels pixels_test = images_test.shape[1] num_examples_test = mnist.test.num_examples # 输出包含测试数据的TFRecord文件。 with tf.python_io.TFRecordWriter("E:\\\\MNIST_data\\\\output_test.tfrecords") as writer: for index in range(num_examples_test): example = _make_example(pixels_test, labels_test[index], images_test[index]) writer.write(example.SerializeToString()) print("TFRecord测试文件已保存。")
# 读取文件。 reader = tf.TFRecordReader() filename_queue = tf.train.string_input_producer(["E:\\\\MNIST_data\\\\output.tfrecords"]) _,serialized_example = reader.read(filename_queue) # 解析读取的样例。 features = tf.parse_single_example( serialized_example, features={ ‘image_raw‘:tf.FixedLenFeature([],tf.string), ‘pixels‘:tf.FixedLenFeature([],tf.int64), ‘label‘:tf.FixedLenFeature([],tf.int64) }) images = tf.decode_raw(features[‘image_raw‘],tf.uint8) labels = tf.cast(features[‘label‘],tf.int32) pixels = tf.cast(features[‘pixels‘],tf.int32) sess = tf.Session() # 启动多线程处理输入数据。 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(10): image, label, pixel = sess.run([images, labels, pixels])
以上是关于吴裕雄 python 神经网络——TensorFlow TFRecord样例程序的主要内容,如果未能解决你的问题,请参考以下文章
吴裕雄 python 神经网络——TensorFlow 输入文件队列
吴裕雄 python 机器学习——人工神经网络感知机学习算法的应用
吴裕雄 python 神经网络——TensorFlow 队列操作
吴裕雄 python 神经网络——TensorFlow 数据集高层操作