PyTorch 数据集:将整个数据集转换为 NumPy

Posted

技术标签:

【中文标题】PyTorch 数据集:将整个数据集转换为 NumPy【英文标题】:PyTorch Datasets: Converting entire Dataset to NumPy 【发布时间】:2019-07-20 17:06:42 【问题描述】:

我正在尝试将 Torchvision MNIST 训练和测试数据集转换为 NumPy 数组,但找不到实际执行转换的文档。

我的目标是获取整个数据集并将其转换为单个 NumPy 数组,最好不要遍历整个数据集。

我查看了How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?,但它并没有解决我的问题。

所以我的问题是,利用torch.utils.data.DataLoader,我将如何将数据集(训练/测试)转换为两个 NumPy 数组,以便所有示例都存在?

注意:我暂时将批量大小保留为默认值 1;我可以将它设置为 60,000 用于训练,10,000 用于测试,但我不想使用那种幻数。

谢谢。

【问题讨论】:

【参考方案1】:

如果我理解正确,您希望获得 MNIST 图像的整个训练数据集(总共 60000 张图像,每个图像大小为 1x28x28 数组,颜色通道为 1)作为大小为 (60000, 1, 28) 的 numpy 数组, 28)?

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalized Tensors 
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)


train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()

这是结果:

>>> train_dataset_array

array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       ...,


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]]], dtype=float32)

【讨论】:

我假设我可以使用next(iter(train_loader))[1].numpy() 来获取标签? 是的,它给出了相应的标签。【参考方案2】:

此任务无需使用torch.utils.data.DataLoader

from torchvision import datasets, transforms

train_set = datasets.MNIST('./data', train=True, download=True)
test_set = datasets.MNIST('./data', train=False, download=True)

train_set_array = train_set.data.numpy()
test_set_array = test_set.data.numpy()

请注意,在这种情况下,目标被排除在外。

【讨论】:

这很有帮助,但它也给你标签吗? 没有实际理由将标签包含在同一数据矩阵中。您可以获取另一个运行目标的 NumPy 矩阵train_set_array_targets = train_set.targets.numpy() 确实如此。我只是错过了Dataset 对象的datatargets 属性——由于某种原因,我似乎在文档中找不到。 这是因为datatargets 不是Dataset 类的属性,而是MNIST 类的属性,该类是Dataset 的子类。我会直接查看torchvision 包的MNIST 的文档,但我认为没有任何东西可以回答这里提出的问题。 这不包括转换,但接受的答案是。

以上是关于PyTorch 数据集:将整个数据集转换为 NumPy的主要内容,如果未能解决你的问题,请参考以下文章

如何创建图神经网络数据集? (pytorch 几何)

如何修复数据集以返回所需的输出(pytorch)

R方法通过将整个数据集向上移动一个小时/向下移动一个小时半年来将标准转换为夏令时?

深度学习之Pytorch——如何使用张量处理文本数据集(语料库数据集)

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext