tensorflow数据加载方式
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow数据加载方式相关的知识,希望对你有一定的参考价值。
tensorflow当前具有三种读取数据的方式:1.预加载(preloaded):在构建tensorflow流图时直接定义常量数据,由于数据是直接镶嵌在流图中,所以当数据量很大时将占用大量内存
import tensorflow as tf
a = tf.constant([1,2,3],name=‘input_a‘)
b = tf.constant([4,5,6],name=‘input_b‘)
c = tf.add(a,b,name=‘sums‘)
sess = tf.Session()
x = sess.run(c)
print(x)
2.填充(feeding):将python产生的数据直接填充到后端,这种方式同样存在数据量大时消耗内存的问题,同时数据类型转换也会增加一些开销
import tensorflow as tf
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
c = tf.add(a,b)
p_a = [1,2,3]
p_b = [4,5,6]
with tf.Session() as sess:
print(sess.run(c, feed_dict={a:p_a, b:p_b}))
3.从文件读取(reading from file):相较于上面两种,这种方式处理量大的数据具有很大优势。tensorflow在从文件中读取数据时主要分两步:
(1)将数据写入TFRecords二进制文件;
‘‘‘创建转换函数,将数据填入到tf.train.Example协议缓冲区中,同时将缓冲区序列化为字符串,
再通过tf.python_io.TFRecordWriter写入TFRecords文件‘‘‘
import os
import tensorflow as tf
def int64_feature(data):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
def bytes_feature(data):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
def convert_tfrecords(data, name):
images = data.images
labels = data.labels
num_examples = data.num_examples
if images.shape[0] != num_examples:
raise ValueError(u‘图片数量与标签数量不一致,分别为%d和%d‘ %(images.shape[0],num_examples))
rows = images.shape[1]
width = images.shape[2]
depth = images.shape[3]
filename = os.path.join(os.path.dirname(__file__), name + ‘.tfrecores‘)
writer = tf.python_io.TFRecoredWriter(filename)
for i in range(num_examples):
image_raw = images[i].tostring()
example = tf.train.Example(features = tf.train.Features(feature = {
‘height‘: int64_feature(rows), ‘width‘:int64_feature(width),
‘depth‘:int64_feature(depth),‘label‘:int64_feature(labels),
‘image_raw‘:bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()
(2)使用队列从二进制文件中读取数据。
以上是关于tensorflow数据加载方式的主要内容,如果未能解决你的问题,请参考以下文章