Pytorch DataLoader 不返回批处理数据

Posted

技术标签:

【中文标题】Pytorch DataLoader 不返回批处理数据【英文标题】:Pytorch DataLoader doesn't return batched data 【发布时间】:2022-01-23 07:02:06 【问题描述】:

我的数据集由从原始图像(人脸补丁和随机人脸补丁之外)获得的图像补丁组成。补丁存储在一个文件夹中,该文件夹具有补丁源自的原始图像的名称。我创建了自己的 DataSet 和 DataLoader,但是当我遍历数据集时,数据不会批量返回。大小为 1 的批次应该包括一个补丁元组数组和一个标签,因此随着批次大小的增加,我们应该得到一个带有标签的元组数组。但是无论批量大小,DataLoader 都只返回一个元组数组。

我的数据集:

  import os
  import cv2 as cv
  import PIL.Image as Image
  import torchvision.transforms as Transforms
  from torch.utils.data import dataset    

  class PatchDataset(dataset.Dataset):
    def __init__(self, img_folder, n_patches):
      self.img_folder = img_folder
      self.n_patches = n_patches
      self.img_names = sorted(os.listdir(img_folder))

      self.transform = Transforms.Compose([
        Transforms.Resize((50, 50)),
        Transforms.ToTensor()
      ])
    
    def __len__(self):
      return len(self.img_names)
    
    def __getitem__(self, idx):
      img_name = self.img_names[idx]
      patch_dir = os.path.join(self.img_folder, img_name)
      patches = []
    
      for i in range(self.n_patches):
        face_patch = cv.imread(os.path.join(patch_dir, f'str(i)_face.png'))
        face_patch = cv.cvtColor(face_patch, cv.COLOR_BGR2RGB)
        face_patch = Image.fromarray(face_patch)
        face_patch = self.transform(face_patch)
    
        patch = cv.imread(os.path.join(patch_dir, f'str(i)_patch.png'))
        patch = cv.cvtColor(patch, cv.COLOR_BGR2RGB)
        patch = Image.fromarray(patch)
        patch = self.transform(patch)

        patches.append((face_patch, patch))
    
      return patches, int(img_name.split('-')[0])

然后我就这样使用它:

X = PatchDataset(PATCHES_DIR, 9)
train_dl = dataloader.DataLoader(
    X,
    batch_size=10,
    drop_last=True
)

for batch_X, batch_Y in train_dl:
  print(len(batch_X))
  print(len(batch_Y))

在此提供的情况下,批量大小为 10,因此打印 batch_Y 会返回正确的数字 (10)。但是batch_X 的打印返回 9,这是补丁对的数量 - 从数据集中只返回一个样本,而不是 10 个样本的批次,每个样本的长度为 9。

【问题讨论】:

通常你必须从torch.utils.data.Dataset继承,但你使用dataset.Dataset我不知道。或者你做了:import torch.utils.data as dataset?如果这不是错误,请提供您使用数据加载器的代码:) @TheodorPeifer 是的,我是这样导入的,DataLoader 也是如此。我添加了您要求的示例并提供了更多信息。 【参考方案1】:

您应该在__get_item__ 函数调用中返回更高一维的tensor,而不是张量的list。你可以使用torch.stack(patches)

def __getitem__(self, idx):
   img_name = self.img_names[idx]
   patch_dir = os.path.join(self.img_folder, img_name)
   patches = []

   for i in range(self.n_patches):
       face_patch = cv.imread(os.path.join(patch_dir, f'str(i)_face.png'))
       face_patch = cv.cvtColor(face_patch, cv.COLOR_BGR2RGB)
       face_patch = Image.fromarray(face_patch)
       face_patch = self.transform(face_patch)

       patch = cv.imread(os.path.join(patch_dir, f'str(i)_patch.png'))
       patch = cv.cvtColor(patch, cv.COLOR_BGR2RGB)
       patch = Image.fromarray(patch)
       patch = self.transform(patch)

       patches.append((face_patch, patch))

   return torch.stack(patches), int(img_name.split('-')[0])

【讨论】:

以上是关于Pytorch DataLoader 不返回批处理数据的主要内容,如果未能解决你的问题,请参考以下文章

不使用多处理但在使用 PyTorch DataLoader 时在 google colab 上出现 CUDA 错误

Pytorch的Dataset与Dataloader之间的关系

Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

如何在 Dataloader 类之外的 pytorch 中创建数据预处理管道?

PyTorch自定义数据集处理/dataset/DataLoader等

PyTorch自定义数据集处理/dataset/DataLoader等