标注数据读取与存储案例:xml读取本地文件存储到pkl

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了标注数据读取与存储案例:xml读取本地文件存储到pkl相关的知识,希望对你有一定的参考价值。

案例:xml读取本地文件存储到pkl

  • ElementTree工具使用,解析xml结构
  • 保存物体坐标结果以及类别
    • pickle工具导出

1. 解析结构

  • 导入
from xml.etree import ElementTree
  • 处理XML库
    • import xml.etree.ElementTree as ET
      • tree = et.parse(filename):形成树状结构
      • tree.getroot():获取树结构的根部分
      • root.find与findall()进行查询XML每个标签的内容.text

定义解析xml结构类,

class XmlProcess(object):

    def __init__(self, data_path):
        self.path_prefix = data_path
        self.num_classes = 8
        self.data = dict()

进行preprocess_xml处理

  • 解析基本信息
def preprocess_xml(self):
        # 找到文件名字
        filenames = os.listdir(self.path_prefix)
        for filename in filenames:
            # 1、XML解析根路径
            tree = ElementTree.parse(self.path_prefix + filename)
            root = tree.getroot()
            bounding_boxes = []
            one_hot_classes = []
            size_tree = root.find('size')
            width = float(size_tree.find('width').text)
            height = float(size_tree.find('height').text)
  • 获取每个对象的坐标
			# 每个图片标记的对象进行坐标获取
            for object_tree in root.findall('object'):
                for bounding_box in object_tree.iter('bndbox'):
                    xmin = float(bounding_box.find('xmin').text)/width
                    ymin = float(bounding_box.find('ymin').text)/height
                    xmax = float(bounding_box.find('xmax').text)/width
                    ymax = float(bounding_box.find('ymax').text)/height
                bounding_box = [xmin, ymin, xmax, ymax]
                bounding_boxes.append(bounding_box)
                class_name = object_tree.find('name').text
                # 将类别进行one_hot编码
                one_hot_class = self.on_hot(class_name)
                one_hot_classes.append(one_hot_class)
  • 获取图片名字,保存bounding_boxes以及类别one_hot编码信息
            image_name = root.find('filename').text
            bounding_boxes = np.asarray(bounding_boxes)
            one_hot_classes = np.asarray(one_hot_classes)
            # 存储图片标注的结果对应的名字,以及图片的标注数据(4个坐标以及onehot编码)
            image_data = np.hstack((bounding_boxes, one_hot_classes))
            self.data[image_name] = image_data

2. one_hot编码函数

    def on_hot(self, name):
        one_hot_vector = [0] * self.num_classes
        if name == 'clothes':
            one_hot_vector[0] = 1
        elif name == 'pants':
            one_hot_vector[1] = 1
        elif name == 'shoes':
            one_hot_vector[2] = 1
        elif name == 'watch':
            one_hot_vector[3] = 1
        elif name == 'phone':
            one_hot_vector[4] = 1
        elif name == 'audio':
            one_hot_vector[5] = 1
        elif name == 'computer':
            one_hot_vector[6] = 1
        elif name == 'books':
            one_hot_vector[7] = 1
        else:
            print('unknown label: %s' % name)
        return one_hot_vector

使用preprocess进行本地保存到pickle文件

if __name__ == '__main__':
    xp = XmlProcess('/Users/huxinghui/workspace/ml/detection/ssd_detection/ssd/datasets/commodity/Annotations/')
    xp.preprocess_xml()
    pickle.dump(xp.data, open('./commodity_gt.pkl', 'wb'))

3. 完整代码

from xml.etree import ElementTree as ET
import numpy as np
import os
import pickle


class XmlProcess(object):
    # 初始化(构造方法)
    def __init__(self, file_path):
        self.xml_path = file_path
        self.num_classes = 8
        self.data = {}

    def process_xml(self):
        """
        处理图片的标注信息,解析图片大小,图片中所有物体位置,类别
        存入序列化的pkl文件
        :return:
        """
        # 1. 找到路径对应的图片
        for filename in os.listdir(self.xml_path):
            et = ET.parse(self.xml_path + filename)
            root = et.getroot()

            # 获取图片基础属性
            # 获取size
            size = root.find('size')
            width = float(size.find('width').text)
            height = float(size.find('height').text)
            depth = float(size.find('depth').text)

            # 2. 对于每张图片,解析其中的多个物体
            bounding_boxes = []
            one_hots = []
            for object_tree in root.findall('object'):
                for res in object_tree.iter('bndbox'):
                    xmin = float(res.find('xmin').text) / width
                    ymin = float(res.find('ymin').text) / height
                    xmax = float(res.find('xmax').text) / width
                    ymax = float(res.find('ymax').text) / height

                bounding_boxes.append([xmin, ymin, xmax, ymax])
                # 每个object都会有一个名称,目标值进行one-hot编码,与预测值进行交叉熵损失
                object_name = object_tree.find('name').text
                object_onehot = self.one_hot(object_name)
                one_hots.append(object_onehot)

            # 进行物体位置和目标值的one_hot编码进行拼接
            bounding_boxes = np.asarray(bounding_boxes)
            one_hots = np.asarray(one_hots)
            image_data = np.hstack((bounding_boxes, one_hots))  # 水平拼接
            # print(image_data)
            self.data[filename] = image_data

        return None

    def one_hot(self, name):
        one_hot_vector = [0] * self.num_classes
        if name == 'clothes':
            one_hot_vector[0] = 1
        elif name == 'pants':
            one_hot_vector[1] = 1
        elif name == 'shoes':
            one_hot_vector[2] = 1
        elif name == 'watch':
            one_hot_vector[3] = 1
        elif name == 'phone':
            one_hot_vector[4] = 1
        elif name == 'audio':
            one_hot_vector[5] = 1
        elif name == 'computer':
            one_hot_vector[6] = 1
        elif name == 'books':
            one_hot_vector[7] = 1
        else:
            print('unknown label: %s' % name)
        return one_hot_vector


if __name__ == '__main__':
    xp = XmlProcess(r"D:\\Python\\PycharmProjects\\DeepLearning\\computerVision\\datasets\\commodity\\Annotations\\\\")
    xp.process_xml()
    print(xp.data)
    pickle.dump(xp.data, open("./commodity_groundtruth2.pkl", 'wb'))

以上是关于标注数据读取与存储案例:xml读取本地文件存储到pkl的主要内容,如果未能解决你的问题,请参考以下文章

c#中怎样读取xml文件中的数据,怎样动态将数据存储到xml文件中去?

图片Bitmap在本地的存储与读取 File

Web APIs BOM- 操作浏览器之综合案例

python读取,写入和更新xml文件

如何从 SQL Server 表中读取图像数据(存储 word 文档)并将其保存到本地文件夹

如何生成输入数据并将其存储到本地文件中我们如何使用kafka读取此输入生成文件的数据