如何将pickle文件中的数据集加载到PyTorch中?
Posted
技术标签:
【中文标题】如何将pickle文件中的数据集加载到PyTorch中?【英文标题】:How to load dataset from pickle files into PyTorch? 【发布时间】:2020-09-27 07:25:15 【问题描述】:我有 X_train(inputs) 和 Y_train(labels) 在单独的 pickle 文件中以整数矩阵的形式。现在,我需要加载它们并使用 PyTorch 进行训练。我尝试了torch.utils.data.DataLoader
和torchvision.datasets.DatasetFolder
,但没有任何效果,否则我可能会在某个地方出错。请提出一个正确的方法。
【问题讨论】:
【参考方案1】:您确实应该通过一些示例清楚地描述您的问题。无论如何,据我了解,您正在寻找这样的东西。
import pickle
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
class YourDataset(Dataset):
def __init__(self, X_Train, Y_Train, transform=None):
self.X_Train = X_Train
self.Y_Train = Y_Train
self.transform = transform
def __len__(self):
return len(self.X_Train)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
x = self.X_Train[idx]
y = self.Y_Train[idx]
if self.transform:
x = self.transform(x)
y = self.transform(y)
return x, y
file = open('FILENAME_X_train', 'rb')
X_train = pickle.load(file)
file.close()
file = open('FILENAME_Y_train', 'rb')
Y_train = pickle.load(file)
file.close()
your_dataset = YourDataset(X_train, Y_train, transform=transforms.Compose([transforms.ToTensor()]))
your_data_loader = DataLoader(your_dataset, batch_size=8, shuffle=True, num_workers=0)
请注意,我没有测试过代码,但我认为它给出了大致的想法。希望对您有所帮助。
【讨论】:
以上是关于如何将pickle文件中的数据集加载到PyTorch中?的主要内容,如果未能解决你的问题,请参考以下文章
Pandas将dataframe保存为pickle文件并加载保存后的pickle文件查看dataframe数据实战
如何将多个 xls 文件中的 xls 数据加载到 hive 中?
如何使用列中的np数组条目创建panda / pickle数据集,以便我可以有效地绘制它们?
利用Python进行数据分析 第6章 数据加载存储与文件格式