tensorflow读取数据之CSV格式

Posted 终有扬眉吐气天

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow读取数据之CSV格式相关的知识,希望对你有一定的参考价值。

tensorflow要想用起来,首先自己得搞定数据输入。官方文档中介绍了几种,1.一次性从内存中读取数据到矩阵中,直接输入;2.从文件中边读边输入,而且已经给设计好了多线程读写模型;3.把网络或者内存中的数据转化为tensorflow的专用格式tfRecord,存文件后再读取。

 

其中,从文件中边读边输入,官方文档举例是用的CSV格式文件。我在网上找了一份代码,修改了一下,因为他的比较简略,我就补充一下遇到的问题

先贴代码

 

#coding=utf-8import tensorflow as tf

import numpy as np

defreadMyFileFormat(fileNameQueue):  

reader = tf.TextLineReader()  

key, value = reader.read(fileNameQueue)  

record_defaults = [[1], [1], [1]]  

col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults)  

features = tf.pack([col1, col2])  

label = col3  

return features, label

definputPipeLine(fileNames = ["1.csv","2.csv"], batchSize =4, numEpochs = None):  

fileNameQueue = tf.train.string_input_producer(fileNames, num_epochs = numEpochs)  

example, label = readMyFileFormat(fileNameQueue)  

min_after_dequeue =8  

capacity = min_after_dequeue +3 * batchSize  

exampleBatch, labelBatch = tf.train.shuffle_batch([example, label], batch_size = batchSize, num_threads = 3, capacity = cap acity, min_after_dequeue = min_after_dequeue)  

return exampleBatch, labelBatch

featureBatch, labelBatch = inputPipeLine(["1.csv","2.csv"], batchSize = 4)

with tf.Session() as sess: # Start populating the filename queue.coord = tf.train.Coordinator()  

threads = tf.train.start_queue_runners(coord=coord)  

# Retrieve a single instance:try:#while not coord.should_stop():

whileTrue:  

example, label = sess.run([featureBatch, labelBatch])print example  

except tf.errors.OutOfRangeError:  

print‘Done reading‘  

finally:  

coord.request_stop()  


coord.join(threads)  

sess.close()


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

其中,record_defaults = [[1], [1], [1]] ,是用于指定矩阵格式以及数据类型的,CSV文件中的矩阵,是NXM的,则此处为1XM,[1]中的1 用于指定数据类型,比如矩阵中如果有小数,则为float,[1]应该变为[1.0]。


col1, col2, col3 = tf.decode_csv(value, record_defaults = record_defaults) , 矩阵中有几列,这里就要写几个参数,比如5列,就要写到col5,不管你到底用多少。否则报错。


tf.pack([col1, col2]) ,好像要求col1与col2是同一数据类型,否则报错。

我的测试数据

 

 

-0.76 15.67 -0.12 15.67
-0.48 12.52 -0.06 12.51
1.33 9.11 0.12 9.1
-0.88 20.35 -0.18 20.36
-0.25 3.99 -0.01 3.99
-0.87 26.25 -0.23 26.25
-1.03 2.87 -0.03 2.87
-0.51 7.81 -0.04 7.81
-1.57 14.46 -0.23 14.46
-0.1 10.02 -0.01 10.02
-0.56 8.92 -0.05 8.92
-1.2 4.1 -0.05 4.1
-0.77 5.15 -0.04 5.15
-0.88 4.48 -0.04 4.48
-2.7 10.82 -0.3 10.82
-1.23 2.4 -0.03 2.4
-0.77 5.16 -0.04 5.15
-0.81 6.15 -0.05 6.15
-0.6 5.01 -0.03 5
-1.25 4.75 -0.06 4.75
-2.53 7.31 -0.19 7.3
-1.15 16.39 -0.19 16.39
-1.7 5.19 -0.09 5.18
-0.62 3.23 -0.02 3.22
-0.74 17.43 -0.13 17.41
-0.77 15.41 -0.12 15.41
0 47 0 47.01
0.25 3.98 0.01 3.98
-1.1 9.01 -0.1 9.01
-1.02 3.87 -0.04 3.87

以上是关于tensorflow读取数据之CSV格式的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow读取CSV数据

TensorFlow读取CSV数据(批量)

Tensorflow csv数据集使用情况

如何读取 csv 文件并将其应用于 tensorflow

tensorflow 从数据库中读取数据

读取自己的数据集