读取TFRecord文件报错

Posted yangxiaoling

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了读取TFRecord文件报错相关的知识,希望对你有一定的参考价值。

读取保存有多个样例的TFRecord文件时报错:

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 14410143 values, but the requested shape has 230400
	 [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, Reshape/shape)]]

从报错信息来看,与reshape有关:解码得到的张量有14410143个元素(即第一张图片 1797×2673×3),而请求的形状只有230400个元素(即第二张图片 240×320×3),说明图像数据和形状信息来自两个不同的样例。

 

先生成保存有两张图片的TFRecord文件

 1 #!coding:utf8
 2 
 3 import tensorflow as tf
 4 import numpy as np
 5 from PIL import Image
 6 
 7 INPUT_DATA = [/home/error/tt/cat.jpg, /home/error/tt/5605502523_05acb00ae7_n.jpg]  # 输入文件
 8 OUTPUT_DATA = /home/error/tt.tfrecords  # 输出文件
 9 
10 def _int64_feature(value):
11     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
12 
13 
14 def _bytes_feature(value):
15     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
16 
17 
18 def create_image_lists(sess):
19         writer1 = tf.python_io.TFRecordWriter(OUTPUT_DATA)
20 
21         # 处理图片数据
22         for f in INPUT_DATA:
23             print(f)
24             image = Image.open(f)
25             image_value = np.asarray(image, np.uint8)
26 
27             height, width, channles = image_value.shape
28             print(height, width, channles, height*width*channles)
29             label = 0
30 
31             example = tf.train.Example(features=tf.train.Features(feature={
32                 name: _bytes_feature(f.encode(utf8)),
33                 image: _bytes_feature(image_value.tostring()),
34                 label: _int64_feature(label),
35                 height: _int64_feature(height),
36                 width: _int64_feature(width),
37                 channels: _int64_feature(channles)
38             }))
39             serialized_example = example.SerializeToString()
40             writer1.write(serialized_example)
41 
42         writer1.close()
43 
44 
45 with tf.Session() as sess:
46     create_image_lists(sess)
47 
48 
49 # /home/error/tt/cat.jpg
50 # 1797 2673 3 14410143
51 # /home/error/tt/5605502523_05acb00ae7_n.jpg
52 # 240 320 3 230400

 

错误读取(在同一个样例的处理过程中多次调用sess.run,每次调用都会从队列中取出一个新的样例,导致图像数据与形状信息不匹配):

 1 #!coding:utf8
 2 import tensorflow as tf
 3 import matplotlib.pyplot as plt
 4 import numpy as np
 5 
 6 OUTPUT_DATA = /home/error/tt.tfrecords  # 输出文件
 7 train_queue = tf.train.string_input_producer([OUTPUT_DATA])
 8 
 9 
10 def read_file(file_queue, sess):
11     reader = tf.TFRecordReader()
12     _, serialized_example = reader.read(file_queue)
13     features = tf.parse_single_example(
14         serialized_example,
15         features={
16             image: tf.FixedLenFeature([], tf.string),
17             label: tf.FixedLenFeature([], tf.int64),
18             height: tf.FixedLenFeature([], tf.int64),
19             width: tf.FixedLenFeature([], tf.int64),
20             channels: tf.FixedLenFeature([], tf.int64),
21         })
22 
23     image, label = features[image], features[label]
24     height, width = features[height], features[width]
25     channels = features[channels]
26     decoded_image = tf.decode_raw(image, tf.uint8)
27 
28 
29     print(sess.run(decoded_image).shape)  # (14410143,)
30     #  为啥打印时,会影响reshape??  sess.run造成的,因为可视化时也会出现这个问题,但不知道原因;
31     # 原因是执行sess.run()时会从队列中重新取一个样例导致样例不同
32 
33     height_val, width_val, channels_val = sess.run([height, width, channels])
34     print(height_val, width_val, channels_val, height_val*width_val*channels_val)  # 240 320 3 230400
35     reshaped_decoded_image = tf.reshape(decoded_image, [height_val, width_val, channels_val])
36     # print(reshaped_decoded_image.shape)  # (240, 320, 3)
37 
38     reshaped_decoded_image_val = sess.run(reshaped_decoded_image)  # reshape时不会报错,当执行运算时才会报错
39     # plt.imshow(sess.run(reshaped_decoded_image))
40     # plt.show()
41 
42 with tf.Session() as sess:
43     tf.local_variables_initializer().run()
44 
45     coord = tf.train.Coordinator()
46     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
47 
48     for _ in range(2):
49         read_file(train_queue, sess)
50 
51     coord.request_stop()
52     coord.join(threads)
53 
54 # 从保存有多个样例的tfrecord文件中读取数据会报错
55 # 报错日志:
56 # InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 14410143 values, but the requested shape has 230400
57 #      [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, Reshape/shape)]]

 

正确读取(只构建一次读取图,并在同一次sess.run中同时取出图像数据及其高、宽、通道数,保证它们来自同一个样例,再用numpy完成reshape):

#!coding:utf8
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# TFRecord file to read (the file written by the writer script).
OUTPUT_DATA = '/home/error/tt.tfrecords'
# Filename queue feeding the TFRecord reader.
train_queue = tf.train.string_input_producer([OUTPUT_DATA])


def read_file(file_queue, sess):
    """Build the ops that read and parse one example from a TFRecord queue.

    Nothing is evaluated here: the function only constructs tensors and
    returns them. The caller must fetch all of them in a single
    ``sess.run`` call; separate ``sess.run`` calls would each take a new
    example from the queue, so the image bytes and the shape values would
    come from different examples.

    Args:
        file_queue: Filename queue passed to ``reader.read``.
        sess: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(decoded_image, label, height, width, channels)`` of
        tensors, where ``decoded_image`` is the output of
        ``tf.decode_raw(image, tf.uint8)`` (not yet reshaped).
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
        })

    image, label = features['image'], features['label']
    height, width = features['height'], features['width']
    channels = features['channels']
    decoded_image = tf.decode_raw(image, tf.uint8)

    return decoded_image, label, height, width, channels



with tf.Session() as sess:
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Build the reading graph once, outside the loop.
    decoded_image, label, height, width, channels = read_file(train_queue, sess)

    try:
        for _ in range(6):
            # One sess.run per iteration fetches every tensor together, so
            # all values are guaranteed to belong to the same example.
            decoded_image_val, label_val, height_val, width_val, channels_val = sess.run(
                [decoded_image, label, height, width, channels])
            print(height_val)
            # Reshape in numpy with this example's own dimensions.
            reshaped_decoded_image = np.reshape(
                decoded_image_val, [height_val, width_val, channels_val])
            plt.imshow(reshaped_decoded_image)
            plt.show()
    finally:
        # Stop and join the queue-runner threads even if reading or
        # plotting raised, so they do not outlive the session.
        coord.request_stop()
        coord.join(threads)

 

附:上文"错误读取"方式产生的完整报错信息如下(采用"正确读取"方式后不会再出现该错误):

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 14410143 values, but the requested shape has 230400
     [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DecodeRaw, Reshape/shape)]]




以上是关于读取TFRecord文件报错的主要内容,如果未能解决你的问题,请参考以下文章

TFRecord文件的读写

吴裕雄--天生自然 pythonTensorFlow图形数据处理:读取MNIST手写图片数据写入的TFRecord文件

tensorflow-TFRecord报错ValueError: Protocol message Feature has no "feature" field.

tensorflow的tfrecord操作代码与数据协议规范

tensorflow读取tfrecord数据集

tensorflow读取数据-tfrecord格式