Error when reading a TFRecord file
Posted by yangxiaoling
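This article was compiled by the editors of 小常识网 (cha138.com) and covers an error that occurs when reading TFRecord files.
Reading a TFRecord file that stores multiple examples raises the following error: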
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 14410143 values, but the requested shape has 230400 [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, Reshape/shape)]]
Judging from the error message, the problem is related to reshape.
First, generate a TFRecord file that stores two images:
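The two numbers in the message are worth checking against the output printed by the writer script: 14410143 = 1797 * 2673 * 3 is the number of bytes in cat.jpg, while 230400 = 240 * 320 * 3 is the number of bytes in the second image. In other words, the raw bytes of one image are being reshaped with the shape of the other. A minimal NumPy sketch of the same mismatch follows (NumPy raises ValueError where TensorFlow raises InvalidArgumentError; the constant and helper names are illustrative, not from the original code):

```python
import numpy as np

# Shapes taken from the writer script's printed output.
CAT_SHAPE = (1797, 2673, 3)    # /home/error/tt/cat.jpg
SMALL_SHAPE = (240, 320, 3)    # /home/error/tt/5605502523_05acb00ae7_n.jpg


def n_values(shape):
    """Number of uint8 values a flat image of this shape contains."""
    return int(np.prod(shape))


def try_reshape(flat, shape):
    """Reshape a flat vector, returning the exception instead of raising it."""
    try:
        return flat.reshape(shape)
    except ValueError as err:
        return err


# Bytes of the large image, reshaped with the small image's shape: fails.
cat_bytes = np.zeros(n_values(CAT_SHAPE), dtype=np.uint8)
print(try_reshape(cat_bytes, SMALL_SHAPE))
```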
#!coding:utf8

import tensorflow as tf
import numpy as np
from PIL import Image

INPUT_DATA = ['/home/error/tt/cat.jpg', '/home/error/tt/5605502523_05acb00ae7_n.jpg']  # input files
OUTPUT_DATA = '/home/error/tt.tfrecords'  # output file


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def create_image_lists(sess):
    writer1 = tf.python_io.TFRecordWriter(OUTPUT_DATA)

    # Process the image data
    for f in INPUT_DATA:
        print(f)
        image = Image.open(f)
        image_value = np.asarray(image, np.uint8)

        height, width, channels = image_value.shape
        print(height, width, channels, height * width * channels)
        label = 0

        example = tf.train.Example(features=tf.train.Features(feature={
            'name': _bytes_feature(f.encode('utf8')),
            # Raw pixel bytes only; the shape is stored in separate features.
            # tobytes() is the non-deprecated name for tostring().
            'image': _bytes_feature(image_value.tobytes()),
            'label': _int64_feature(label),
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'channels': _int64_feature(channels)
        }))
        serialized_example = example.SerializeToString()
        writer1.write(serialized_example)

    writer1.close()


with tf.Session() as sess:
    create_image_lists(sess)


# /home/error/tt/cat.jpg
# 1797 2673 3 14410143
# /home/error/tt/5605502523_05acb00ae7_n.jpg
# 240 320 3 230400
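The writer stores only the flattened pixel bytes, which is why height, width and channels have to be saved as separate features: without them the reader cannot rebuild the 3-D array. A minimal NumPy sketch of that round trip, in which np.frombuffer plays the role of tf.decode_raw (the dict-based record and the to_record/from_record helpers are hypothetical stand-ins for tf.train.Example, not part of the original code):

```python
import numpy as np


def to_record(image):
    """Mimic what the writer stores: raw bytes plus the shape as separate ints."""
    height, width, channels = image.shape
    return {'image': image.tobytes(), 'height': height,
            'width': width, 'channels': channels}


def from_record(record):
    """Mimic decode_raw + reshape: bytes -> flat uint8 vector -> H x W x C array."""
    flat = np.frombuffer(record['image'], dtype=np.uint8)  # 1-D, like decode_raw's output
    return flat.reshape(record['height'], record['width'], record['channels'])


img = (np.arange(2 * 3 * 3) % 256).astype(np.uint8).reshape(2, 3, 3)
restored = from_record(to_record(img))
print(restored.shape)
```

The reshape only succeeds when the bytes and the shape come from the same record, which is exactly the invariant the failing reader below breaks.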
Incorrect way to read the file:
#!coding:utf8
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

OUTPUT_DATA = '/home/error/tt.tfrecords'  # output file
train_queue = tf.train.string_input_producer([OUTPUT_DATA])


def read_file(file_queue, sess):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
        })

    image, label = features['image'], features['label']
    height, width = features['height'], features['width']
    channels = features['channels']
    decoded_image = tf.decode_raw(image, tf.uint8)


    print(sess.run(decoded_image).shape)  # (14410143,)
    # Why does printing affect the reshape?? It is caused by sess.run (the same problem shows up when visualizing), but I didn't know why at first;
    # the reason is that every sess.run() call dequeues a new example from the queue, so the calls see different examples.

    height_val, width_val, channels_val = sess.run([height, width, channels])
    print(height_val, width_val, channels_val, height_val * width_val * channels_val)  # 240 320 3 230400
    reshaped_decoded_image = tf.reshape(decoded_image, [height_val, width_val, channels_val])
    # print(reshaped_decoded_image.shape)  # (240, 320, 3)

    reshaped_decoded_image_val = sess.run(reshaped_decoded_image)  # building tf.reshape does not fail; the error is raised only when the op runs
    # plt.imshow(sess.run(reshaped_decoded_image))
    # plt.show()

with tf.Session() as sess:
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for _ in range(2):
        read_file(train_queue, sess)

    coord.request_stop()
    coord.join(threads)

# Reading from a TFRecord file that stores multiple examples raises an error
# Error log:
# InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 14410143 values, but the requested shape has 230400
# [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, Reshape/shape)]]
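The root cause noted in the comments above, that every sess.run() dequeues a fresh example, can be illustrated without TensorFlow. The following is a simplified pure-Python simulation, not TensorFlow code: the hypothetical FakeExampleQueue class stands in for string_input_producer plus TFRecordReader, and each run() call dequeues exactly one record, the way a single sess.run() on the reader's output does. The buggy pattern fetches the image and its shape in two separate calls; the fixed pattern fetches them in one call.

```python
import itertools

import numpy as np


class FakeExampleQueue:
    """Hypothetical stand-in for the queue + reader pipeline.

    Every run() call dequeues one new example, mirroring how each
    sess.run() on reader-backed tensors pulls a fresh record.
    """

    def __init__(self, examples):
        # string_input_producer cycles through the file indefinitely by default.
        self._examples = itertools.cycle(examples)

    def run(self, *keys):
        example = next(self._examples)  # exactly one dequeue per call
        return tuple(example[k] for k in keys)


def make_example(shape):
    """Build a fake decoded example with a flat uint8 image and its shape."""
    h, w, c = shape
    return {'image': np.zeros(h * w * c, np.uint8), 'height': h, 'width': w, 'channels': c}


def buggy_read(queue):
    """Two separate runs: the image and the shape come from different examples."""
    (image,) = queue.run('image')
    h, w, c = queue.run('height', 'width', 'channels')
    return image.size, h * w * c


def fixed_read(queue):
    """One run: the image and its shape come from the same example."""
    image, h, w, c = queue.run('image', 'height', 'width', 'channels')
    return image.reshape(h, w, c)


examples = [make_example((4, 5, 3)), make_example((2, 3, 3))]
print(buggy_read(FakeExampleQueue(examples)))  # sizes disagree, so a reshape would fail
```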
Correct way to read the file:
#!coding:utf8
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

OUTPUT_DATA = '/home/error/tt.tfrecords'  # output file
train_queue = tf.train.string_input_producer([OUTPUT_DATA])


def read_file(file_queue, sess):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
        })

    image, label = features['image'], features['label']
    height, width = features['height'], features['width']
    channels = features['channels']
    decoded_image = tf.decode_raw(image, tf.uint8)

    # Only build the ops and return the tensors; no sess.run() inside this function.
    return decoded_image, label, height, width, channels

    # decoded_image_val, label_val, height_val, width_val, channels_val = sess.run([decoded_image, label, height, width, channels])
    # # print(decoded_image.shape)  # (230400,)
    # # Why does printing affect the reshape?? It is caused by sess.run (the same problem shows up when visualizing), but I didn't know why at first.
    # # The reason is that each run dequeues a new example from the queue, so the examples differ.


with tf.Session() as sess:
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    decoded_image, label, height, width, channels = read_file(train_queue, sess)
    for _ in range(6):
        # Iterate over sess.run itself: fetching all tensors in a single sess.run call
        # guarantees that they all come from the same example.
        decoded_image_val, label_val, height_val, width_val, channels_val = sess.run(
            [decoded_image, label, height, width, channels])
        print(height_val)
        reshaped_decoded_image = np.reshape(decoded_image_val, [height_val, width_val, channels_val])
        plt.imshow(reshaped_decoded_image)
        plt.show()

    coord.request_stop()
    coord.join(threads)
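On TensorFlow 2.x, where tf.Session, tf.train.string_input_producer and queue runners are only available under tf.compat.v1, a rough equivalent of the correct version might use tf.data instead. This is a sketch under stated assumptions, not the author's code: it assumes TF 2.x with eager execution, keeps only the image and shape features, and the helper names (write_images, parse_and_reshape, read_images, demo) are hypothetical. Because the reshape happens inside the map function, each record is parsed once and reshaped with its own height, width and channels, so the mismatch above cannot occur.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

FEATURES = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'channels': tf.io.FixedLenFeature([], tf.int64),
}


def write_images(path, images):
    """Write uint8 H x W x C arrays as raw bytes plus separate shape features."""
    with tf.io.TFRecordWriter(path) as writer:
        for img in images:
            h, w, c = img.shape
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
                'channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_and_reshape(serialized):
    """Parse one record and reshape its bytes with that same record's shape."""
    f = tf.io.parse_single_example(serialized, FEATURES)
    flat = tf.io.decode_raw(f['image'], tf.uint8)
    shape = tf.stack([f['height'], f['width'], f['channels']])  # int64 shape vector
    return tf.reshape(flat, shape)


def read_images(path):
    """Iterate the dataset eagerly and return NumPy arrays."""
    dataset = tf.data.TFRecordDataset([path]).map(parse_and_reshape)
    return [img.numpy() for img in dataset]


def demo():
    """Round-trip two arrays of different shapes through a temporary file."""
    images = [np.zeros((4, 5, 3), np.uint8), np.ones((2, 3, 3), np.uint8)]
    path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecords')
    write_images(path, images)
    return [img.shape for img in read_images(path)]


print(demo())
```

Calling demo() writes two small arrays of different shapes to a temporary file and reads them back in order, each with its own shape.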