pytorch关于数据载入的代码
Posted Mario cai
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch关于数据载入的代码相关的知识,希望对你有一定的参考价值。
第一步:读取文件
class PET_dataset(Dataset):
def __init__(self,path):
self.image_path_1 = os.path.join(path,'train1')
self.image_path_0 = os.path.join(path,'train0')
self.test_path_1 = os.path.join(path,'test1')
self.test_path_0 = os.path.join(path,'test0')
self.img_path_1 = sorted(os.listdir(self.image_path_1))
self.img_path_0 = sorted(os.listdir(self.image_path_0))
self.test_path_1 = sorted(os.listdir(self.test_path_1))
self.test_path_0 = sorted(os.listdir(self.test_path_0))
定义__init__
后,执行实例化的过程须变成PET_dataset(path)
,新建的实例本身,连带其中的参数,会一并传给__init__
函数自动并执行它。所以__init__
函数的参数列表会在开头多出一项,它永远指代新建的那个实例对象,Python语法要求这个参数必须要有,而名称随意,习惯上就命为self
。
第二步:获得文件长度
def __len__(self):
return len(self.img_path_0)
第三步:读取文件内数据
def __getitem__(self, item):
train1_list, train0_list = self.img_path_1[item], self.img_path_0[item]
#获得文件内的item
train1_dcm = os.path.join(self.image_path_1, train1_list)#,allow_pickle=True)
# img_array=read_data(img_dcm)
#读取item
train1_array = nib.load(train1_dcm) #根据哪个库提取数据
train1_array = np.array(train1_array.dataobj) #数据转成numpy
#label 如上
train0_dcm = os.path.join(self.image_path_0, train0_list)#,allow_pickle=True)
# lab_array=read_data(lab_dcm)
train0_array=nib.load(train0_dcm)
train0_array = np.array(train0_array.dataobj)
#将numpy转化为tensor
train1 = torch.FloatTensor(train1_array)
train0 = torch.FloatTensor(train0_array)
#在最前面加一个维度,CNN常用,conv2d的输入必须是四维
train1 = train1.unsqueeze(0)
train0 = train0.unsqueeze(0)
return (train1,train0)
第四步:封装起来
def get_load(dir, batch_size,shuffle,num_workers):
dataset_ = PET_dataset(dir)
data_loader = DataLoader(dataset=dataset_,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers)
return data_loader
验证
data_loader = get_load(path,1,True,0)
def train():
total_iters = 0
for iter_, (x, y ) in enumerate(data_loader):
total_iters += 1
print(x.shape)
print(y.shape)
train()
以上是关于pytorch关于数据载入的代码的主要内容,如果未能解决你的问题,请参考以下文章