Converting object-detection .xml annotations to TFRecord format for TensorFlow training

Posted by allen-rg


This script converts VOC-style .xml object-detection annotations, together with their images, into a TFRecord file for TensorFlow training.
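For orientation, the conversion script reads annotation files that appear to follow the Pascal VOC layout: a `size` element with `width` and `height`, and one `object` element per box with `name`, `difficult`, and `bndbox`. The snippet below is a minimal sketch of which fields get read. The annotation values and the `read_boxes` helper are made up for illustration and are not part of the original post.

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation in the Pascal VOC layout; all values are made up.
SAMPLE_XML = """
<annotation>
  <size><width>400</width><height>200</height><depth>3</depth></size>
  <object>
    <name>cat</name>
    <difficult>0</difficult>
    <bndbox><xmin>100</xmin><ymin>50</ymin><xmax>300</xmax><ymax>150</ymax></bndbox>
  </object>
  <object>
    <name>dog</name>
    <difficult>1</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>60</xmax><ymax>80</ymax></bndbox>
  </object>
</annotation>
"""


def read_boxes(xml_text):
    """Return (width, height, boxes); each box is (name, xmin, xmax, ymin, ymax)."""
    root = ET.fromstring(xml_text)
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    boxes = []
    for obj in root.iter('object'):
        # Objects flagged as difficult are skipped, as in the conversion script.
        if int(obj.find('difficult').text) == 1:
            continue
        bnd = obj.find('bndbox')
        boxes.append((
            obj.find('name').text,
            float(bnd.find('xmin').text), float(bnd.find('xmax').text),
            float(bnd.find('ymin').text), float(bnd.find('ymax').text),
        ))
    return width, height, boxes


print(read_boxes(SAMPLE_XML))
```

The tuple order (xmin, xmax, ymin, ymax) matches the order the script passes to `convert`, which differs from the order of the elements inside `bndbox`.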

 

import xml.etree.ElementTree as ET
import numpy as np
import os
import tensorflow as tf
from PIL import Image

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]


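def convert(size, box):
    # size is (image_width, image_height); box is (xmin, xmax, ymin, ymax) in pixels.
    # Returns [x_center, y_center, width, height], each divided by the image size.
    dw = 1./size[0]
    dh = 1./size[1]
    x = (box[0] + box[1])/2.0
    y = (box[2] + box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return [x, y, w, h]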


def convert_annotation(image_id):
    # Parse one annotation file and return a flat list of 30 * 5 floats:
    # (x, y, w, h, class_id) per box, zero-padded when there are fewer than 30 boxes.
    tree = ET.parse('F:/xml/%s.xml' % image_id)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    bboxes = []
    for i, obj in enumerate(root.iter('object')):
        if i > 29:  # look at no more than the first 30 objects
            break
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        # Skip unknown classes and objects marked as difficult.
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b) + [cls_id]
        bboxes.extend(bb)
    if len(bboxes) < 30*5:
        bboxes = bboxes + [0, 0, 0, 0, 0]*(30 - len(bboxes)//5)

    return np.array(bboxes, dtype=np.float32).flatten().tolist()

def convert_img(image_id):
    # Resize the image to 416 x 416, scale pixel values to [0, 1],
    # and return the raw float32 bytes.
    image = Image.open('F:/snow leopard/test_im/%s.jpg' % image_id)
    resized_image = image.resize((416, 416), Image.BICUBIC)
    image_data = np.array(resized_image, dtype='float32')/255
    img_raw = image_data.tobytes()
    return img_raw

filename = 'test.tfrecords'
# The original post used tf.python_io.TFRecordWriter, the TensorFlow 1.x name;
# tf.io.TFRecordWriter is the name available in TensorFlow 2.x.
writer = tf.io.TFRecordWriter(filename)
# image_ids = open('F:/snow leopard/test_im/%s.txt' % (
#     year, year, image_set)).read().strip().split()

image_ids = os.listdir('F:/snow leopard/test_im/')
# print(filename)
for image_id in image_ids:
    print(image_id)
    image_id = image_id.split('.')[0]  # drop the file extension
    print(image_id)

    xywhc = convert_annotation(image_id)
    img_raw = convert_img(image_id)

    # One record per image: the padded box vector and the raw image bytes.
    example = tf.train.Example(features=tf.train.Features(feature={
        'xywhc':
                tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
        'img':
                tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
    writer.write(example.SerializeToString())
writer.close()
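When a record is read back, the `xywhc` feature is a flat list of 30 × 5 = 150 floats. As a rough sketch of how that vector could be turned back into pixel boxes, the snippet below inverts the normalization done by `convert`. `decode_xywhc` is a hypothetical helper, not part of the original script. It assumes that rows with zero width and zero height are padding, since a padded row would otherwise look like class 0.

```python
def convert(size, box):
    # Same formula as the script: box is (xmin, xmax, ymin, ymax) in pixels.
    w_img, h_img = size
    x = (box[0] + box[1]) / 2.0 / w_img
    y = (box[2] + box[3]) / 2.0 / h_img
    w = (box[1] - box[0]) / w_img
    h = (box[3] - box[2]) / h_img
    return [x, y, w, h]


def decode_xywhc(xywhc, size):
    """Hypothetical inverse: flat (x, y, w, h, class_id) rows -> pixel boxes."""
    w_img, h_img = size
    boxes = []
    for i in range(0, len(xywhc), 5):
        x, y, w, h, cls_id = xywhc[i:i + 5]
        # Assumption: an all-zero extent marks a padding row.
        if w == 0 and h == 0:
            continue
        xmin = (x - w / 2.0) * w_img
        xmax = (x + w / 2.0) * w_img
        ymin = (y - h / 2.0) * h_img
        ymax = (y + h / 2.0) * h_img
        boxes.append((int(cls_id), xmin, xmax, ymin, ymax))
    return boxes


# One box of class 7 ("cat" in the class list) in a 400 x 200 image,
# padded to 30 rows the same way the script pads.
row = convert((400, 200), (100, 300, 50, 150)) + [7]
padded = row + [0, 0, 0, 0, 0] * 29
print(decode_xywhc(padded, (400, 200)))
```

Because the coordinates are normalized per axis, passing `(416, 416)` as `size` instead gives the boxes in the resized image that the script stores.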

  

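Two ways to list the images in a folder with Python:

import os
imagelist = os.listdir('./images/')      # names of all entries in the images folder

 

import glob
imagelist = sorted(glob.glob('./images/' + 'frame_*.png'))      # only the paths matching a shared pattern; preferable to the first method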
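The difference between the two calls is easy to check in a scratch directory. The sketch below is illustrative only: the file names and the `compare_listings` helper are made up. It shows that `os.listdir` returns a bare name for every entry, while `glob.glob` returns only paths matching the pattern, with the directory included.

```python
import glob
import os
import tempfile


def compare_listings(names):
    """Create empty files in a temporary folder and list them both ways."""
    with tempfile.TemporaryDirectory() as folder:
        for name in names:
            open(os.path.join(folder, name), 'w').close()
        # os.listdir: every entry, names only, in arbitrary order.
        all_names = sorted(os.listdir(folder))
        # glob.glob: only entries matching the pattern, as full paths.
        matched = sorted(glob.glob(os.path.join(folder, 'frame_*.png')))
        matched_names = [os.path.basename(p) for p in matched]
    return all_names, matched_names


all_names, matched_names = compare_listings(
    ['frame_2.png', 'frame_1.png', 'notes.txt'])
print(all_names)
print(matched_names)
```

Note that `sorted` orders strings lexicographically, so `frame_10.png` would come before `frame_2.png` unless the frame numbers are zero-padded.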


References:

https://blog.csdn.net/CV_YOU/article/details/80778392

https://github.com/raytroop/YOLOv3_tf

