Pytorch:“KeyError:在 DataLoader 工作进程 0 中捕获 KeyError。”

Posted

技术标签:

【中文标题】Pytorch:“KeyError:在 DataLoader 工作进程 0 中捕获 KeyError。”【英文标题】:Pytorch: "KeyError: Caught KeyError in DataLoader worker process 0." 【发布时间】:2021-11-22 08:29:59 【问题描述】:

问题描述:

我正在尝试使用 Pytorch 自定义数据集加载图像数据。我做了一点深入研究,发现我的图像集包含 2 种形状 (512,512,3) 和 (1024,1024) 。我的假设是,由于上述原因,它会抛出以下错误。

注意:该代码能够读取一些图像,但是对于其中的一些图像,它会抛出以下错误消息。这就是对图像数据做一点 EDA 的原因,发现数据集中有 2 种不同形状的图像。

第一季度。如何对此类图像数据进行预处理以进行训练?

第二季度。我可能会看到以下错误消息还有其他原因吗?

错误信息:

KeyError                                  Traceback (most recent call last)
<ipython-input-163-aa3385de8026> in <module>
----> 1 train_features, train_labels = next(iter(train_dataloader))
  2 print(f"Feature batch shape: train_features.size()")
  3 print(f"Labels batch shape: train_labels.size()")
  4 img = train_features[0].squeeze()
  5 label = train_labels[0]

 ~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils  /data/dataloader.py in __next__(self)
519             if self._sampler_iter is None:
520                 self._reset()
521             data = self._next_data()
522             self._num_yielded += 1
523             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1201             else:
1202                 del self._task_info[idx]
1203                 return self._process_data(data)
1204 
1205     def _try_put_index(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1227         self._try_put_index()
1228         if isinstance(data, ExceptionWrapper):
1229             data.reraise()
1230         return data
1231 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
423             # have message field
424             raise self.exc_type(message=msg)
425         raise self.exc_type(msg)
426 
427 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas  /core/indexes/base.py", line 2898, in get_loc
return self._engine.get_loc(casted_key)
File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 101, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1032, in    pandas._libs.hashtable.Int64HashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1039, in   pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 16481

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-161-f38b78d77dcb>", line 19, in __getitem__
img_path =os.path.join(self.img_dir,self.image_ids[idx])
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 882, in __getitem__
return self._get_value(key)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 990, in _get_value
loc = self.index.get_loc(label)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2900, in get_loc
raise KeyError(key) from err
KeyError: 16481

代码:

from torchvision.io import read_image
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
     # init
    def __init__(self,dataset,transforms=None,target_transforms=None):
        #self.train_data = pd.read_csv("Data/train_data.csv")
        self.image_ids = dataset.image_id
        self.image_labels = dataset.label
        self.img_dir = 'Data/images'
        self.transforms = transforms
        self.target_transforms = target_transforms
# len
    def __len__(self):
        return len(self.image_ids)
# getitem
    def __getitem__(self,idx):
        # image path
        img_path =os.path.join(self.img_dir,self.image_ids[idx])
        # image
        image = read_image(img_path)
        label = self.image_labels[idx]
    # transform image
        if self.transforms:
             image = self.transforms(image)
    # transform target
        if self.target_transforms:
             label = self.target_transforms(label)
    return image, label

代码:train_data是csv文件的pandas对象,包含image id,labelsl信息。

  from sklearn.model_selection import train_test_split
  X_train, X_test = train_test_split(train_data, test_size=0.1, random_state=42)
  train_df = CustomImageDataset(X_train)
  train_dataloader = torch.utils.data.DataLoader(
        train_df,
        batch_size=64,
        num_workers=1,
        shuffle=True,
    )

【问题讨论】:

【参考方案1】:

发现代码有问题。

Pytorch Custom Dataloader function "getitem" 使用 idx 来检索数据,我的猜测是,它从 len 函数中知道 idx 的范围,例如:0,直到 len (数据集中的行)。

就我而言,我已经有一个以 idx 作为列之一的熊猫数据集 (train_data)。当我将其随机拆分为 X_train 和 X_test 时,很少有数据行与 idx 一起移动到 X_test。

现在,当我将 X_train 发送到自定义数据加载器时,它正在尝试使用 idx 获取行的 image_id,而该 idx 恰好位于 X_test 数据集中。这会导致错误为 keyerror: 16481,即 X_train 数据集中不存在 idx=16481 的行。它在拆分期间被移动到 X_test。

呼……

【讨论】:

你拯救了我的一天。我还根据数据框的索引进行索引,而 df_train 自拆分以来显然缺乏一些索引。我只需要在创建数据集时重置索引。【参考方案2】:

我在 PyTorch 中微调基于 DistilBertModel 转换器的模型并更换其头部时遇到了同样的错误。

我忘记在 train_test_split 之后重置 train_dataframe 和 test_dataframe 的索引,导致我的CustomDataset索引不正确。

【讨论】:

以上是关于Pytorch:“KeyError:在 DataLoader 工作进程 0 中捕获 KeyError。”的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 中的常用矩阵操作

Pytorch Note1 Pytorch介绍

pytorch_geometric + MinkowskiEngine

1. PyTorch是什么?

1. PyTorch是什么?

对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码