tensorflowxun训练自己的数据集之从tfrecords读取数据

Posted 康小武

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflowxun训练自己的数据集之从tfrecords读取数据相关的知识,希望对你有一定的参考价值。

  当训练数据量较小时,采用直接读取文件的方式,当训练数据量非常大时,直接读取文件的方式太耗内存,这时应采用高效的读取方法,读取tfrecords文件,这其实是一种二进制文件。tensorflow为其内置了各种存储和读取的函数,方便调用。

  不知道为啥,从tfrecords中读取数据用于训练时,收敛得更快,更平稳。上面两个图是使用tfrecords的准确率和loss值变化,下面是直接读取文件的准确率和loss值变化。

 

 

1 生成记录样本的记录文件

 1 root_dir = os.getcwd()
 2 
 3 def getTrianList():
 4     with open("train.txt","w") as f:
 5         for file in os.listdir(root_dir+\'\\\\dataSet\'):
 6             for picFile in os.listdir(root_dir+"\\\\dataSet\\\\"+file):
 7                 f.write("dataSet/"+file+"/"+picFile+" "+file+"\\n")
 8                 print(picFile)
 9 if __name__=="__main__":
10     getTrianList()

  将样本文件路径和标签统一记录到一个txt中,后面生成tfrecords文件就是通过读取这些信息。

  

  注意文件路径和标签之间采用空格,不要使用制表符。

2 读取txt存于数组中

 

1 def load_file(example_list_file):
2     lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[(\'col1\', \'S120\'), (\'col2\', \'i8\')])
3     examples = []
4     labels = []
5     for example,label in lines:
6         examples.append(example)
7         labels.append(label)
8     #convert to numpy array
9     return np.asarray(examples),np.asarray(labels),len(lines)

  这段代码主要用来读取第1步生成的txt,将文件路径和标签存于数组中

3 读取图片

1 def extract_image(filename,height,width):
2     print(filename)
3     image = cv2.imread(filename)
4     image = cv2.resize(image,(height,width))
5     b,g,r = cv2.split(image)
6     rgb_image = cv2.merge([r,g,b])
7     return rgb_image

  使用cv2读取图片文件

4 转化为tfrecords文件

 1 def trans2tfRecord(trainFile,name,output_dir,height,width):
 2     if not os.path.exists(output_dir) or os.path.isfile(output_dir):
 3         os.makedirs(output_dir)
 4     _examples,_labels,examples_num = load_file(train_file)
 5     filename = name + \'.tfrecords\'
 6     writer = tf.python_io.TFRecordWriter(filename)
 7     for i,[example,label] in enumerate(zip(_examples,_labels)):
 8         print("NO{}".format(i))
 9         #need to convert the example(bytes) to utf-8
10         example = example.decode("UTF-8")
11         image = extract_image(example,height,width)
12         image_raw = image.tostring()
13         example = tf.train.Example(features=tf.train.Features(feature={
14                 \'image_raw\':_bytes_feature(image_raw),
15                 \'height\':_int64_feature(image.shape[0]),
16                  \'width\': _int64_feature(32),  
17                 \'depth\': _int64_feature(32),  
18                  \'label\': _int64_feature(label)                        
19                 }))
20         writer.write(example.SerializeToString())
21     writer.close()
1 def _int64_feature(value):  
2     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
3   
4 def _bytes_feature(value):  
5     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  

5 从tfrecords中读取训练数据

 1 def read_tfRecord(file_tfRecord):
 2     queue = tf.train.string_input_producer([file_tfRecord])
 3     reader = tf.TFRecordReader()
 4     _,serialized_example = reader.read(queue)
 5     features = tf.parse_single_example(
 6             serialized_example,
 7             features={
 8           \'image_raw\': tf.FixedLenFeature([], tf.string),  
 9           \'height\': tf.FixedLenFeature([], tf.int64), 
10           \'width\':tf.FixedLenFeature([], tf.int64),
11           \'depth\': tf.FixedLenFeature([], tf.int64),  
12           \'label\': tf.FixedLenFeature([], tf.int64)  
13                     }
14             )
15     image = tf.decode_raw(features[\'image_raw\'],tf.uint8)
16     #height = tf.cast(features[\'height\'], tf.int64)
17     #width = tf.cast(features[\'width\'], tf.int64)
18     image = tf.reshape(image,[32,32,3])
19     image = tf.cast(image, tf.float32)
20     image = tf.image.per_image_standardization(image)
21     label = tf.cast(features[\'label\'], tf.int64)
22     print(image,label)
23     return image,label

  从tfrecords文件中读取image和label,训练的时候,直接使用tf.train.batch函数生成用于训练的batch即可。

1 image_batches,label_batches = tf.train.batch([image, label], batch_size=16, capacity=20)

  其余的部分跟之前的训练步骤一样。

以上是关于tensorflowxun训练自己的数据集之从tfrecords读取数据的主要内容,如果未能解决你的问题,请参考以下文章

物体检测之从RCNN到Faster RCNN

物体检测之从RCNN到Faster RCNN

使用@tf.function 进行自定义张量流训练的内存泄漏

使用 TF-slim 训练的模型与 python 推理完美配合,但使用 C++ 给出完全错误的结果

opencv之从视频帧中截取图片

使用Tensorflow Object Detection API训练自己的数据,并使用编译成功的模型进行识别