ccc-pytorch-宝可梦自定义数据集实战-加载数据部分
Posted 扔出去的回旋镖
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ccc-pytorch-宝可梦自定义数据集实战-加载数据部分相关的知识,希望对你有一定的参考价值。
文章目录
第一步:构建路径与种类的映射关系
import os
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Dataset skeleton that maps class-folder names to integer labels.

    Every sub-directory directly under ``root`` is treated as one class.
    Sample loading is not implemented at this step.
    """

    def __init__(self, root, resize, model):
        """Scan ``root`` and build the ``name2label`` mapping.

        Args:
            root: Directory whose sub-directories are the class folders.
            resize: Target image edge length; stored but not used yet.
            model: Split selector (e.g. ``'train'``); accepted but not
                used at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Sort the folder names so label ids are stable across runs,
        # and skip plain files living next to the class folders.
        class_names = [
            name for name in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, name))
        ]
        self.name2label = {name: idx for idx, name in enumerate(class_names)}
        print(self.name2label)

    def __len__(self):
        """Placeholder; not implemented at this step."""
        pass

    def __getitem__(self, idx):
        """Placeholder; not implemented at this step."""
        pass
def main(root='D:\\pythonProject\\pythonProject39\\pokeman'):
    """Instantiate the training split of ``Pokeman`` rooted at ``root``.

    Args:
        root: Dataset directory; defaults to the path used in the tutorial.

    Returns:
        The constructed ``Pokeman`` instance.
    """
    return Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第二步:载入所有的宝可梦图像
import os,glob
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Dataset skeleton that indexes class folders and lists their images.

    Every sub-directory directly under ``root`` is treated as one class.
    Sample loading is not implemented at this step.
    """

    def __init__(self, root, resize, model):
        """Build ``name2label`` and enumerate all image files.

        Args:
            root: Directory whose sub-directories are the class folders.
            resize: Target image edge length; stored but not used yet.
            model: Split selector (e.g. ``'train'``); accepted but not
                used at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Sort the folder names so label ids are stable across runs,
        # and skip plain files living next to the class folders.
        class_names = [
            name for name in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, name))
        ]
        self.name2label = {name: idx for idx, name in enumerate(class_names)}
        print(self.name2label)
        self.load_csv('images.csv')

    def load_csv(self, filename):
        """Collect and print the paths of all images under the class folders.

        Args:
            filename: Name of the CSV file; not used at this step.

        Returns:
            List of matched ``*.png``, ``*.jpg`` and ``*.jpeg`` paths,
            grouped by class folder in ``name2label`` order.
        """
        images = []
        for name in self.name2label:
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(self.root, name, pattern))
        # Example from the tutorial's dataset: 1167 paths such as
        # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(images), images)
        return images

    def __len__(self):
        """Placeholder; not implemented at this step."""
        pass

    def __getitem__(self, idx):
        """Placeholder; not implemented at this step."""
        pass
def main(root='D:\\pythonProject\\pythonProject39\\pokeman'):
    """Instantiate the training split of ``Pokeman`` rooted at ``root``.

    Args:
        root: Dataset directory; defaults to the path used in the tutorial.

    Returns:
        The constructed ``Pokeman`` instance.
    """
    return Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第三步:打散顺序并通过路径名提取映射关系构建映射文件
import csv
import os,glob
import random
from torch.utils.data import Dataset
class Pokeman(Dataset):
    """Dataset skeleton backed by a shuffled ``path,label`` CSV index.

    Every sub-directory directly under ``root`` is treated as one class.
    Sample loading is not implemented at this step.
    """

    def __init__(self, root, resize, model):
        """Build ``name2label`` and load the image/label index.

        Args:
            root: Directory whose sub-directories are the class folders.
            resize: Target image edge length; stored but not used yet.
            model: Split selector (e.g. ``'train'``); accepted but not
                used at this step.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Sort the folder names so label ids are stable across runs,
        # and skip plain files living next to the class folders.
        class_names = [
            name for name in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, name))
        ]
        self.name2label = {name: idx for idx, name in enumerate(class_names)}
        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

    def load_csv(self, filename):
        """Return image paths and labels, creating the CSV index if missing.

        When ``root/filename`` does not exist, all ``*.png``, ``*.jpg`` and
        ``*.jpeg`` files under the class folders are collected, shuffled and
        written as ``path,label`` rows. The file is then read back, so later
        runs reuse the same shuffled order.

        Args:
            filename: CSV file name, resolved relative to ``self.root``.

        Returns:
            Tuple ``(images, labels)`` of equal-length lists: path strings
            and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # Example from the tutorial's dataset: 1167 paths such as
            # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent folder name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        images, labels = [], []
        # newline='' lets the csv module handle line endings itself.
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Placeholder; not implemented at this step."""
        pass

    def __getitem__(self, idx):
        """Placeholder; not implemented at this step."""
        pass
def main(root='D:\\pythonProject\\pythonProject39\\pokeman'):
    """Instantiate the training split of ``Pokeman`` rooted at ``root``.

    Args:
        root: Dataset directory; defaults to the path used in the tutorial.

    Returns:
        The constructed ``Pokeman`` instance.
    """
    return Pokeman(root, 224, 'train')


if __name__ == '__main__':
    main()
第四步:完善选取、获取图片信息功能并可视化
import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class Pokeman(Dataset):
    """Image-folder dataset with a persisted 60/20/20 train/val/test split.

    Every sub-directory directly under ``root`` is treated as one class.
    The shuffled ``path,label`` index is stored in ``root/images.csv`` so
    all splits are cut from the same ordering.
    """

    def __init__(self, root, resize, model):
        """Build ``name2label``, load the index and keep the requested split.

        Args:
            root: Directory whose sub-directories are the class folders.
            resize: Edge length every image is resized to (square).
            model: ``'train'`` keeps the first 60% of the index, ``'val'``
                the next 20%, and any other value the final 20%.
        """
        super().__init__()
        self.root = root
        self.resize = resize
        print(root)
        # Sort the folder names so label ids are stable across runs,
        # and skip plain files living next to the class folders.
        class_names = [
            name for name in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, name))
        ]
        self.name2label = {name: idx for idx, name in enumerate(class_names)}
        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')

        # Compute both cut points from the full index length before slicing,
        # and apply the same slice to images and labels so they stay aligned.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if model == 'train':
            part = slice(None, train_end)
        elif model == 'val':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, None)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self, filename):
        """Return image paths and labels, creating the CSV index if missing.

        When ``root/filename`` does not exist, all ``*.png``, ``*.jpg`` and
        ``*.jpeg`` files under the class folders are collected, shuffled and
        written as ``path,label`` rows. The file is then read back, so later
        runs reuse the same shuffled order.

        Args:
            filename: CSV file name, resolved relative to ``self.root``.

        Returns:
            Tuple ``(images, labels)`` of equal-length lists: path strings
            and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # Example from the tutorial's dataset: 1167 paths such as
            # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent folder name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file', filename)

        images, labels = [], []
        # newline='' lets the csv module handle line endings itself.
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an ``(image_tensor, label_tensor)`` pair.

        The image is opened from its stored path, converted to RGB, resized
        to ``resize x resize`` and converted with ``transforms.ToTensor``.
        """
        # path looks like
        # 'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        path, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
        ])
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(path) as raw:
            img = tf(raw.convert('RGB'))
        label = torch.tensor(label)
        return img, label
def main(root='D:\\pythonProject\\pythonProject39\\pokeman'):
    """Load the first training sample and show it in a visdom window.

    Args:
        root: Dataset directory; defaults to the path used in the tutorial.
    """
    import visdom

    viz = visdom.Visdom()
    db = Pokeman(root, 224, 'train')
    # Fetch the first sample produced by iterating over the dataset.
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape)
    viz.images(x, win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()
第五步:对数据进行预处理
import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class Pokeman(Dataset):
def __init__(self,root,resize,model):
super(Pokeman,self).__init__()
self.root=root
self.resize=resize
self.name2label=
print(root)
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root,name)):
continue
self.name2label[name] = len(self.name2label.keys())
print(self.name2label)
self.images,self.labels = self.load_csv('images.csv')
if model == 'train':
self.images = self.images[:int(0.6*len(self.images))]
self.labels = self.labels[:int(0.6*len(self.labels))]
elif model == 'val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.images))]
else :
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.images)):]
def load_csv(self,filename):
if not os.path.exists(os.path.join(self.root,filename)):
images = []
for name in self.name2label.keys():
images +=glob.glob(os.path.join(self.root,name,'*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
#1167,'D:\\\\pythonProject\\\\pythonProject39\\\\pokeman\\\\bulbasaur\\\\00000000.png'
print(len(images),images)
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer = csv.writer(f)
for img in images :
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img,label]<以上是关于ccc-pytorch-宝可梦自定义数据集实战-加载数据部分的主要内容,如果未能解决你的问题,请参考以下文章