pytorch关于数据载入的代码

Posted Mario cai

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch关于数据载入的代码相关的知识,希望对你有一定的参考价值。

 第一步:读取文件

class PET_dataset(Dataset):
    def __init__(self,path):
        self.image_path_1 = os.path.join(path,'train1')
        self.image_path_0 = os.path.join(path,'train0')
        self.test_path_1 = os.path.join(path,'test1')
        self.test_path_0 = os.path.join(path,'test0')
        self.img_path_1 = sorted(os.listdir(self.image_path_1))
        self.img_path_0 = sorted(os.listdir(self.image_path_0))
        self.test_path_1 = sorted(os.listdir(self.test_path_1))
        self.test_path_0 = sorted(os.listdir(self.test_path_0))

定义__init__后,执行实例化的过程须变成PET_dataset(path)新建的实例本身,连带其中的参数,会一并传给__init__函数自动并执行它。所以__init__函数的参数列表会在开头多出一项,它永远指代新建的那个实例对象,Python语法要求这个参数必须要有,而名称随意,习惯上就命为self

 第二步:获得文件长度

    def __len__(self):
        return len(self.img_path_0)

第三步:读取文件内数据

    def __getitem__(self, item):
        train1_list, train0_list = self.img_path_1[item], self.img_path_0[item]
        #获得文件内的item
        train1_dcm = os.path.join(self.image_path_1, train1_list)#,allow_pickle=True)
        # img_array=read_data(img_dcm)
        #读取item
        train1_array = nib.load(train1_dcm)              #根据哪个库提取数据
        train1_array = np.array(train1_array.dataobj)    #数据转成numpy

        #label 如上
        train0_dcm = os.path.join(self.image_path_0, train0_list)#,allow_pickle=True)
        # lab_array=read_data(lab_dcm)
        train0_array=nib.load(train0_dcm)
        train0_array = np.array(train0_array.dataobj)


        #将numpy转化为tensor
        train1 = torch.FloatTensor(train1_array)
        train0 = torch.FloatTensor(train0_array)
        #在最前面加一个维度,CNN常用,conv2d的输入必须是四维    
        train1 = train1.unsqueeze(0)
        train0 = train0.unsqueeze(0)
        return  (train1,train0)

第四步:封装起来

def get_load(dir, batch_size,shuffle,num_workers):
    dataset_ = PET_dataset(dir)
    data_loader = DataLoader(dataset=dataset_,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers)
    return data_loader

验证

data_loader = get_load(path,1,True,0)
def train():
    total_iters = 0
    for iter_, (x, y ) in enumerate(data_loader):
        total_iters += 1
        print(x.shape)
        print(y.shape)

train()

以上是关于pytorch关于数据载入的代码的主要内容,如果未能解决你的问题,请参考以下文章

利用pytorch的载入训练npy类型数据代码

利用pytorch的载入训练npy类型数据代码

利用pytorch的载入训练npy类型数据代码

Pytorch学习

用Pytorch训练分类模型

FasterRCNN代码解读