使用数据加载器对异常进行批处理

Posted

技术标签:

【中文标题】使用数据加载器对异常进行批处理【英文标题】:Batching irregularities with data loader 【发布时间】:2021-12-08 10:41:20 【问题描述】:

我在 .txt 文件中有一些数据和一个由两行组成的实例,两行都有 100 个元素。第一行定义问题,第二行定义解决方案。尽管这不是一个好主意,但我尝试在数据中使用监督设置。但是,我面临批处理问题。我已经为数据加载器和完成这项工作的主 for 循环添加了代码。

我遇到的问题是,如果 我将 batch_size 设置为 5 并且 preds 数组具有正确的形式。但是,labels 数组多了一个维度,而不是 5 个整数,它有 5 个完整的问题解决方案。

我认为问题出在数据加载器中,但无法解决。我对这个概念有点陌生,我一直试图找到这个超过一个星期,但到目前为止还没有解决。

数据加载器:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np
from torch.utils.data import Dataset

class load_dataset(Dataset):
    def __init__(self, data_file='data.txt', transform=None):
        super().__init__()
        data = np.loadtxt(data_file)
        data = torch.Tensor(data)
        self.data = data[::2]
        self.targets = data[1::2]

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        adj, target = self.data[index], self.targets[index]
        return adj, target

主循环:

for inputs, labels in loaders["train"]:
    inputs, labels = inputs.view([batch_size, 100]), labels.data
    scores = mps(inputs)
    _, preds = torch.max(scores, 1)
    print("preds: ")
    print(preds)
    print("labels: ")
    print(labels)

输出:

preds:
tensor([0, 0, 0, 0, 0])
labels:
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

【问题讨论】:

【参考方案1】:

您尚未展示如何定义数据加载器,但假设您使用 torch.utils.data.DataLoader 包装 load_dataset 并设置 batch_size=5

如果您将批处理大小设置为5,那么您将在一个批处理中拥有 5 个“问题”和相应的 5 个“解决方案”。每个有 100 个组件。这意味着inputslabels 将是两个形状为(batch_size=5, 100) 的张量。

【讨论】:

首先感谢@Ivan,您的假设是正确的。但是,批处理不就是一开始的事情吗?我想我会拆分 1 个数据实例,甚至是它的片段。我的输出维度与输入(它的 100)相同。 也许我遗漏了一些东西,但您解释说 问题 可以由 100 个组件来描述,对我来说,这 100 个特征向量对应于您的模型输入,对吗?还是一个问题对应多个(100 个不同的)输出?

以上是关于使用数据加载器对异常进行批处理的主要内容,如果未能解决你的问题,请参考以下文章

加载配置信息

再次加载时存储的 Tfidf-Vectorizer ValueError

未处理的异常: 未能加载文件或程序集

vs2013外接程序”VMDebugger”加载异常处理

如何在数据框中指定缺失值

ThreeJS-纹理加载进度(十四)