如何将pickle文件中的数据集加载到PyTorch中?

Posted

技术标签:

【中文标题】如何将pickle文件中的数据集加载到PyTorch中?【英文标题】:How to load dataset from pickle files into PyTorch? 【发布时间】:2020-09-27 07:25:15 【问题描述】:

我有 X_train(inputs) 和 Y_train(labels) 在单独的 pickle 文件中以整数矩阵的形式。现在,我需要加载它们并使用 PyTorch 进行训练。我尝试了torch.utils.data.DataLoadertorchvision.datasets.DatasetFolder,但没有任何效果,否则我可能会在某个地方出错。请提出一个正确的方法。

【问题讨论】:

【参考方案1】:

您确实应该通过一些示例清楚地描述您的问题。无论如何,据我了解,您正在寻找这样的东西。

import pickle
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader


class YourDataset(Dataset):

    def __init__(self, X_Train, Y_Train, transform=None):
        self.X_Train = X_Train
        self.Y_Train = Y_Train
        self.transform = transform

    def __len__(self):
        return len(self.X_Train)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x = self.X_Train[idx]
        y = self.Y_Train[idx]

        if self.transform:
            x = self.transform(x)
            y = self.transform(y)

        return x, y


file = open('FILENAME_X_train', 'rb')
X_train = pickle.load(file)
file.close()

file = open('FILENAME_Y_train', 'rb')
Y_train = pickle.load(file)
file.close()

your_dataset = YourDataset(X_train, Y_train, transform=transforms.Compose([transforms.ToTensor()]))

your_data_loader = DataLoader(your_dataset, batch_size=8, shuffle=True, num_workers=0)

请注意,我没有测试过代码,但我认为它给出了大致的想法。希望对您有所帮助。

【讨论】:

以上是关于如何将pickle文件中的数据集加载到PyTorch中?的主要内容,如果未能解决你的问题,请参考以下文章

Pandas将dataframe保存为pickle文件并加载保存后的pickle文件查看dataframe数据实战

如何将多个 xls 文件中的 xls 数据加载到 hive 中?

如何使用列中的np数组条目创建panda / pickle数据集,以便我可以有效地绘制它们?

利用Python进行数据分析 第6章 数据加载存储与文件格式

如何将 CSV 文件中的数据加载到 numpy 数组中[重复]

如何在 C# 中将数据集加载到 libsvm 中