tensorflow数据加载方式

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow数据加载方式相关的知识,希望对你有一定的参考价值。

tensorflow当前具有三种读取数据的方式:
1.预加载(preloaded):在构建tensorflow流图时直接定义常量数据,由于数据是直接镶嵌在流图中,所以当数据量很大时将占用大量内存

import tensorflow as tf
a = tf.constant([1,2,3],name=‘input_a‘)
b = tf.constant([4,5,6],name=‘input_b‘)
c = tf.add(a,b,name=‘sums‘)
sess = tf.Session()
x = sess.run(c)
print(x)

2.填充(feeding):将python产生的数据直接填充到后端,这种方式同样存在数据量大时消耗内存的问题,同时数据类型转换也会增加一些开销

import tensorflow as tf
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
c = tf.add(a,b)
p_a = [1,2,3]
p_b = [4,5,6]
with tf.Session() as sess:
    print(sess.run(c, feed_dict={a:p_a, b:p_b}))

3.从文件读取(reading from file):相较于上面两种,这种方式处理量大的数据具有很大优势。tensorflow在从文件中读取数据时主要分两步:
(1)将数据写入TFRecords二进制文件;

‘‘‘创建转换函数,将数据填入到tf.train.Example协议缓冲区中,同时将缓冲区序列化为字符串,
  再通过tf.python_io.TFRecordWriter写入TFRecords文件‘‘‘
import os
import tensorflow as tf
def int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
def bytes_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
def convert_tfrecords(data, name):
    images = data.images
    labels = data.labels
    num_examples = data.num_examples
    if images.shape[0] != num_examples:
        raise ValueError(u‘图片数量与标签数量不一致,分别为%d和%d‘ %(images.shape[0],num_examples))
    rows = images.shape[1]
    width = images.shape[2]
    depth = images.shape[3]
    filename = os.path.join(os.path.dirname(__file__), name + ‘.tfrecores‘)
    writer = tf.python_io.TFRecoredWriter(filename)
    for i in range(num_examples):
        image_raw = images[i].tostring()
        example = tf.train.Example(features = tf.train.Features(feature = {
                            ‘height‘: int64_feature(rows), ‘width‘:int64_feature(width),
                            ‘depth‘:int64_feature(depth),‘label‘:int64_feature(labels),
                            ‘image_raw‘:bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
    writer.close()

(2)使用队列从二进制文件中读取数据。

以上是关于tensorflow数据加载方式的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow加载数据

手写数字识别——基于全连接层和MNIST数据集

tensorflow中卷积层输出特征尺寸计算和padding参数解析

重新加载时刷新片段

用于数据加载的 Android 活动/片段职责

如何在android中将json数据加载到片段中