ccc-pytorch-宝可梦自定义数据集实战-加载数据部分

Posted 扔出去的回旋镖

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ccc-pytorch-宝可梦自定义数据集实战-加载数据部分相关的知识,希望对你有一定的参考价值。

文章目录

第一步:构建路径与种类的映射关系

import os
from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton that maps each class sub-directory of ``root`` to a label.

    At this stage only the name -> label mapping is built; ``__len__`` and
    ``__getitem__`` are placeholders.

    Args:
        root: Directory whose immediate sub-directories are the class names.
        resize: Stored on the instance; not used by this version.
        model: Accepted for interface compatibility; not used by this version.
    """

    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        # The initialiser was missing (`self.name2label=` is a syntax error);
        # the loop below needs an empty dict to fill.
        self.name2label={}
        print(root)
        # Sorted so that the label assigned to each class is reproducible.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            # The next free label is the number of classes registered so far.
            self.name2label[name] = len(self.name2label)

        print(self.name2label)

    def __len__(self):
        # Placeholder: returns None, so len(instance) is not usable yet.
        pass

    def __getitem__(self, idx):
        # Placeholder: no samples are loaded in this version.
        pass

def main():
    """Build a Pokeman dataset from the hard-coded folder with mode 'train'."""
    dataset_root = 'D:\\pythonProject\\pythonProject39\\pokeman'
    dataset = Pokeman(dataset_root, 224, 'train')


if __name__ == "__main__":
    main()

第二步:载入所有的宝可梦图像

import os,glob

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton that maps class folders to labels and lists their images.

    ``__len__`` and ``__getitem__`` are still placeholders in this version.

    Args:
        root: Directory whose immediate sub-directories are the class names.
        resize: Stored on the instance; not used by this version.
        model: Accepted for interface compatibility; not used by this version.
    """

    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        # The initialiser was missing (`self.name2label=` is a syntax error);
        # the loop below needs an empty dict to fill.
        self.name2label={}
        print(root)
        # Sorted so that the label assigned to each class is reproducible.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            # The next free label is the number of classes registered so far.
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.load_csv('images.csv')

    def load_csv(self,filename):
        """Collect and print every png/jpg/jpeg path under each class folder.

        ``filename`` is not used yet and nothing is returned in this version.
        """
        images = []
        for name in self.name2label:
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(self.root, name, pattern))
        # Example from the author's run: 1167 paths such as
        # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(images),images)

    def __len__(self):
        # Placeholder: returns None, so len(instance) is not usable yet.
        pass

    def __getitem__(self, idx):
        # Placeholder: no samples are served in this version.
        pass

def main():
    """Build a Pokeman dataset from the hard-coded folder with mode 'train'."""
    dataset_root = 'D:\\pythonProject\\pythonProject39\\pokeman'
    dataset = Pokeman(dataset_root, 224, 'train')


if __name__ == "__main__":
    main()

第三步:打散顺序并通过路径名提取映射关系构建映射文件

import csv
import os,glob
import random

from torch.utils.data import Dataset

class Pokeman(Dataset):
    """Dataset skeleton that indexes images into a CSV of (path, label) rows.

    ``__len__`` and ``__getitem__`` are still placeholders in this version.

    Args:
        root: Directory whose immediate sub-directories are the class names.
        resize: Stored on the instance; not used by this version.
        model: Accepted for interface compatibility; not used by this version.
    """

    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        # The initialiser was missing (`self.name2label=` is a syntax error);
        # the loop below needs an empty dict to fill.
        self.name2label={}
        print(root)
        # Sorted so that the label assigned to each class is reproducible.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            # The next free label is the number of classes registered so far.
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        If the CSV does not exist yet, it is first created from a shuffled
        listing of every png/jpg/jpeg file under each class folder.

        Args:
            filename: CSV file name, relative to ``self.root``.

        Returns:
            A pair of equally long lists: image paths and integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # Example from the author's run: 1167 paths such as
            # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            # Shuffle once, before writing, so the stored order is what
            # every later run reads back.
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file',filename)

        # Read back unconditionally. Previously this part was nested inside
        # the branch above, so with an existing CSV the method returned None
        # and the unpacking in __init__ raised TypeError.
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        # Placeholder: returns None, so len(instance) is not usable yet.
        pass

    def __getitem__(self, idx):
        # Placeholder: no samples are served in this version.
        pass

def main():
    """Build a Pokeman dataset from the hard-coded folder with mode 'train'."""
    dataset_root = 'D:\\pythonProject\\pythonProject39\\pokeman'
    dataset = Pokeman(dataset_root, 224, 'train')


if __name__ == "__main__":
    main()


第四步:完善选取、获取图片信息功能并可视化

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    """Image dataset backed by a CSV index, split 60/20/20 by position.

    Args:
        root: Directory whose immediate sub-directories are the class names.
        resize: Side length (pixels) every image is resized to in
            ``__getitem__``.
        model: ``'train'`` selects the first 60% of the index, ``'val'`` the
            next 20%, and any other value the final 20%.
    """

    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        # The initialiser was missing (`self.name2label=` is a syntax error);
        # the loop below needs an empty dict to fill.
        self.name2label={}
        print(root)
        # Sorted so that the label assigned to each class is reproducible.
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            # The next free label is the number of classes registered so far.
            self.name2label[name] = len(self.name2label)

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        # Compute both cut points from the full length BEFORE slicing.
        # Previously the 'val' and 'test' label slices used len(self.images)
        # after self.images had already been truncated, which left the labels
        # misaligned with (or shorter than) the images.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            part = slice(None, train_end)
        elif model == 'val':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, None)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self,filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        If the CSV does not exist yet, it is first created from a shuffled
        listing of every png/jpg/jpeg file under each class folder.

        Args:
            filename: CSV file name, relative to ``self.root``.

        Returns:
            A pair of equally long lists: image paths and integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # Example from the author's run: 1167 paths such as
            # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            # Shuffle once, before writing, so the stored order (and hence
            # the train/val/test split) is the same on every later run.
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file',filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``(image_tensor, label_tensor)``.

        The image is opened from its stored path, converted to RGB, resized to
        ``(resize, resize)`` and passed through ``transforms.ToTensor()``.
        """
        # img is a path such as
        # D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path -> PIL image
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label

def main():
    """Fetch the first sample of the 'train' split and show it in visdom."""
    import visdom

    board = visdom.Visdom()
    dataset_root = 'D:\\pythonProject\\pythonProject39\\pokeman'
    dataset = Pokeman(dataset_root, 224, 'train')

    # Take the first sample the dataset yields when iterated.
    sample_iter = iter(dataset)
    image, target = next(sample_iter)
    print('sample:', image.shape, target.shape)

    window_opts = dict(title='sample_x')
    board.images(image, win='sample_x', opts=window_opts)


if __name__ == "__main__":
    main()

第五步:对数据进行预处理

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label=
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        if model == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif model == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.images))]
        else :
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.images)):]

    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):

            images = []
            for name in self.name2label.keys():
                images +=glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1167,'D:\\\\pythonProject\\\\pythonProject39\\\\pokeman\\\\bulbasaur\\\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images :
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label]<

以上是关于ccc-pytorch-宝可梦自定义数据集实战-加载数据部分的主要内容,如果未能解决你的问题,请参考以下文章

童年的回忆:精灵宝可梦

Comet OJ 1023 [欢乐赛]第001话 宝可梦,就决定是你了!题解

打造属于自己的 宝可梦终端

分类---概率生成模型

ZJNU 2135 - 小智的宝可梦

把宝可梦搬到终端后,摸鱼也不会被老板发现了,收集对战玩法一应俱全|开源...