如何在 pytorch 中处理大型数据集

Posted

技术标签:

【中文标题】如何在 pytorch 中处理大型数据集【英文标题】:How to work with large dataset in pytorch 【发布时间】:2019-07-12 05:13:00 【问题描述】:

我有一个不适合内存 (150G) 的庞大数据集,我正在寻找在 pytorch 中使用它的最佳方法。数据集由几个.npz 文件组成,每个文件有 10k 个样本。我试图建立一个Dataset

class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = 
        for f in self.files:
            # Load file in as a nmap
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        raise NotImplementedException()

    def __getitem__(self, idx):                
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count < idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y

但实际提取样本时,需要 30 多秒。看起来整个 .npz 已打开,存储在 RAM 中并访问了正确的索引。 如何提高效率?

编辑

这似乎是对.npz文件see post的误解,但是有没有更好的方法?

解决方案建议

正如@covariantmonkey 所建议的,lmdb 可能是一个不错的选择。目前,由于问题来自.npz 文件而不是memmap,我通过将.npz 包文件拆分为几个.npy 文件来重构我的数据集。我现在可以使用与memmap 相同的逻辑,并且速度非常快(加载样本需要几毫秒)。

【问题讨论】:

【参考方案1】:

单个.npz 文件有多大?一个月前我也遇到过类似的情况。各种forum 帖子,后来谷歌搜索我去了lmdb 路线。这就是我所做的

    将大型数据集分成足够小的文件,以便我可以放入 gpu 中 — 每个文件本质上都是我的 minibatch。在这个阶段我没有优化加载时间只是内存。 用key = filenamedata = np.savez_compressed(stff)创建一个lmdb索引

lmdb 为您处理 mmap,并且加载速度非常快。

问候, 一个

PS:savez_compessed 需要一个字节对象,因此您可以执行类似的操作

output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb

【讨论】:

以上是关于如何在 pytorch 中处理大型数据集的主要内容,如果未能解决你的问题,请参考以下文章

如何在 python 中处理大型图像数据集?

小白学习PyTorch教程十基于大型电影评论数据集训练第一个LSTM模型

深度学习之Pytorch——如何使用张量处理文本数据集(语料库数据集)

在 R 中处理大型数据集

如何使用批处理为大型数据集拟合 Keras ImageDataGenerator

如何在pytorch中进行并行处理