tf.train.string_input_producer()

Posted by helloworld0604

Reading data from files with the TF 1.x queue-based input pipeline.

Official documentation

[Screenshot of the official API documentation]
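Since the screenshot is not reproduced here, the sketch below shows how the function is typically called in TF 1.x. The signature in the comment is written from memory of the 1.x API and should be checked against the official docs:

import tensorflow as tf

# string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None,
#                       capacity=32, shared_name=None, name=None, cancel_op=None)
# Returns a queue of file-name strings; a QueueRunner for it is added to the graph.
filename_queue = tf.train.string_input_producer(
    ["file0.csv", "file1.csv"],  # list (or 1-D string tensor) of file names
    num_epochs=None,             # None = cycle through the file list indefinitely
    shuffle=True)                # shuffle the file order within each epoch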

Basic usage

The example reads CSV files; to read TFRecord files, swap the reader for tf.TFRecordReader (a sketch follows the example below).

import tensorflow as tf  # TensorFlow 1.x queue-based API

# Queue of input file names; by default the list is shuffled and repeated indefinitely.
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(12):
        # Retrieve a single instance:
        example, label = sess.run([features, col5])
        print(example, label)

    coord.request_stop()
    coord.join(threads)

Output:

[Screenshot of the printed feature and label values]
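As noted above, reading TFRecord files follows the same pattern with the reader swapped out. A minimal sketch, assuming a file my_data.tfrecords whose examples contain an int64 "label" and a bytes "image_raw" feature (the file name and feature names are made up for illustration):

import tensorflow as tf

filename_queue = tf.train.string_input_producer(["my_data.tfrecords"])

reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)

# Parse one serialized tf.train.Example into tensors.
# "label" and "image_raw" are example feature names; adjust to your data.
features = tf.parse_single_example(
    serialized_example,
    features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    })
label = features["label"]
image = tf.decode_raw(features["image_raw"], tf.uint8)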

Combining with batching

import tensorflow as tf
def read_my_file_format(filename_queue):
    # Generic pattern: reader = tf.SomeReader()
    reader = tf.TextLineReader()
    key, record_string = reader.read(filename_queue)
    # Generic pattern: example, label = tf.some_decoder(record_string)
    record_defaults = [[1], [1], [1], [1], [1]]
    col1, col2, col3, col4, col5 = tf.decode_csv(record_string, record_defaults=record_defaults)
    # Generic pattern: processed_example = some_processing(example)
    features = tf.stack([col1, col2, col3, col4])
    return features, col5

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
    example, label = read_my_file_format(filename_queue)
    # Recommended: capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity,
                              min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

x, y = input_pipeline(["file0.csv", "file1.csv"], batch_size=5, num_epochs=4)

sess = tf.Session()
# num_epochs creates a local counter variable, so local variables must be initialized too.
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps (here we just fetch and print one batch per iteration).
        example, label = sess.run([x, y])
        print(example, label)

except tf.errors.OutOfRangeError:
    print("Done training -- epoch limit reached")
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

Output:

[Screenshot of the printed batches]
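When decoding becomes the bottleneck, the same pipeline can use several reading/decoding threads. A sketch of two variants, following the pattern above and reusing example, label, read_my_file_format and filename_queue from the batching code (the thread counts and batch sizes are illustrative):

# Variant 1: several threads pulling from one reader/decoder pipeline.
example_batch, label_batch = tf.train.shuffle_batch(
    [example, label], batch_size=5, num_threads=4,
    capacity=100 + 3 * 5, min_after_dequeue=100)

# Variant 2: one reader per list element, with batches mixed from all of them.
example_list = [read_my_file_format(filename_queue) for _ in range(4)]
example_batch, label_batch = tf.train.shuffle_batch_join(
    example_list, batch_size=5,
    capacity=100 + 3 * 5, min_after_dequeue=100)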
