Python PASCAL VOC 2012 dataset

Foreword: this post, compiled by the editors at cha138.com, walks through a PyTorch Dataset for the PASCAL VOC 2012 data (images plus segmentation label maps); hopefully it is of some reference value.
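
The loader below expects the dataset root to be organized roughly as follows. This tree is inferred from __init__ and _load_data, not the official VOCdevkit layout; each split file lists one image ID (the file stem, without extension) per line:

<root>/
├── list/
│   ├── train.txt    # one image ID per line
│   └── val.txt
├── img/
│   └── ori/
│       └── <ID>.jpg  # input images
└── gt/
    └── <ID>.png      # single-channel label maps, 255 = ignore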

#!/usr/bin/env python
# coding: utf-8
#
# Author:   Kazuto Nakashima
# URL:      http://kazuto1011.github.io
# Created:  2017-10-30

from __future__ import print_function

import os.path as osp
import random
from collections import defaultdict

import cv2
import numpy as np
import torch
from torch.utils import data
from tqdm import tqdm

_MEAN = [104.008, 116.669, 122.675]  # per-channel BGR means (Caffe-style ImageNet statistics)

_VERSION = "1.1"


class Voc12(data.Dataset):
    def __init__(
        self,
        root,
        split="train",
        image_size=473,
        crop_size=473,
        scale=True,
        flip=True,
        preload=False,
    ):
        self.root = root
        self.split = split
        self.image_size = (
            image_size if isinstance(image_size, tuple) else (image_size, image_size)
        )
        self.crop_size = (
            crop_size if isinstance(crop_size, tuple) else (crop_size, crop_size)
        )
        self.scale = scale  # scale and crop
        self.flip = flip
        self.preload = preload
        self.mean = np.array(_MEAN)
        self.files = defaultdict(list)
        self.images = []
        self.labels = []
        self.ignore_label = 255

        # Load the image IDs for each split
        for split in ["train", "val"]:
            with open(osp.join(self.root, "list", split + ".txt"), "r") as file_list:
                self.files[split] = [id_.rstrip() for id_ in file_list]

        if self.preload:
            self._preload_data()

        # Disable OpenCV's internal threading; it can conflict with DataLoader workers
        cv2.setNumThreads(0)

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        if self.preload:
            image, label = self.images[index], self.labels[index]
        else:
            image_id = self.files[self.split][index]
            image, label = self._load_data(image_id)
        image, label = self._transform(image, label)
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return image.astype(np.float32), label.astype(np.int64)

    def _transform(self, image, label):
        if self.scale:
            # Scaling
            scale_factor = random.uniform(0.5, 1.5)
            image = cv2.resize(
                image,
                None,
                fx=scale_factor,
                fy=scale_factor,
                interpolation=cv2.INTER_LINEAR,
            )
            label = cv2.resize(
                label,
                None,
                fx=scale_factor,
                fy=scale_factor,
                interpolation=cv2.INTER_NEAREST,
            )
            h, w = label.shape
            # Padding
            if scale_factor < 1.0:
                pad_h = max(self.image_size[0] - h, 0)
                pad_w = max(self.image_size[1] - w, 0)
                if pad_h > 0 or pad_w > 0:
                    image = cv2.copyMakeBorder(
                        image,
                        pad_h // 2,
                        pad_h - pad_h // 2,
                        pad_w // 2,
                        pad_w - pad_w // 2,
                        cv2.BORDER_CONSTANT,
                        value=(0.0, 0.0, 0.0),
                    )
                    label = cv2.copyMakeBorder(
                        label,
                        pad_h // 2,
                        pad_h - pad_h // 2,
                        pad_w // 2,
                        pad_w - pad_w // 2,
                        cv2.BORDER_CONSTANT,
                        value=(self.ignore_label,),
                    )
            # Random cropping
            h, w = label.shape
            off_h = random.randint(0, h - self.crop_size[0])
            off_w = random.randint(0, w - self.crop_size[1])
            image = image[
                off_h : off_h + self.crop_size[0], off_w : off_w + self.crop_size[1]
            ]
            label = label[
                off_h : off_h + self.crop_size[0], off_w : off_w + self.crop_size[1]
            ]
        if self.flip:
            # Random flipping
            if random.random() < 0.5:
                image = np.fliplr(image).copy()  # HWC
                label = np.fliplr(label).copy()  # HW
        return image, label

    def _load_data(self, image_id):
        image_path = osp.join(self.root, 'img', 'ori', '{}.jpg'.format(image_id))
        label_path = osp.join(self.root, 'gt', '{}.png'.format(image_id))
        # Load an image (BGR), resize, and subtract the per-channel mean
        # (cv2.resize expects (width, height); image_size is square here)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
        image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_LINEAR)
        image -= self.mean
        # Load a label map
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        #label[label==255] = self.ignore_label  # unlabeled (255 -> -1)
        label = cv2.resize(label, self.image_size, interpolation=cv2.INTER_NEAREST)
        return image, label

    def _preload_data(self):
        for image_id in tqdm(
            self.files[self.split],
            desc="Preloading...",
            leave=False,
            dynamic_ncols=True,
        ):
            image, label = self._load_data(image_id)
            self.images.append(image)
            self.labels.append(label)

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Version: {}\n".format(_VERSION)
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        fmt_str += "    Split: {}\n".format(self.split)
        fmt_str += "    Root Location: {}\n".format(self.root)
        return fmt_str


if __name__ == "__main__":
    import matplotlib.cm as cm
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    dataset_root = "/data1/jayzjwang/data/datasets/voc"
    kwargs = {"nrow": 10, "padding": 30}
    batch_size = 100

    dataset = Voc12(root=dataset_root, split="train")
    print(dataset)

    loader = data.DataLoader(dataset, batch_size=batch_size)

    for i, batch in tqdm(
        enumerate(loader), total=int(np.ceil(len(dataset) / batch_size)), leave=False
    ):
        imgs, labels = batch

        if i == 0:
            # Undo the BGR mean subtraction so the images are viewable
            mean = torch.tensor(_MEAN).view(1, 3, 1, 1)
            imgs += mean.expand_as(imgs)
            img = make_grid(imgs, **kwargs).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1].astype(np.uint8)  # BGR -> RGB

            # Color-code the label grid with the jet colormap; grid padding
            # (pad_value=-1, i.e. 0 after the +1 shift) is rendered black
            label = make_grid(
                labels[:, np.newaxis, ...], pad_value=-1, **kwargs
            ).numpy()
            label_ = np.transpose(label, (1, 2, 0))[..., 0].astype(np.float32) + 1
            label = cm.jet(label_ / 183.)[..., :3] * 255
            label *= (label_ != 0)[..., None]
            label = label.astype(np.uint8)

            img = np.hstack((img, label))
            #plt.figure(figsize=(20, 20))
            #plt.imshow(img)
            #plt.axis("off")
            #plt.tight_layout()
            #plt.savefig("./docs/data.png", bbox_inches="tight", transparent=True)
            plt.imsave("./docs/data.png", img)
            #plt.show()
            quit()
