如何在 pytorch 中处理大型数据集
Posted
技术标签:
【中文标题】如何在 pytorch 中处理大型数据集【英文标题】:How to work with large dataset in pytorch 【发布时间】:2019-07-12 05:13:00 【问题描述】:我有一个不适合内存 (150G) 的庞大数据集,我正在寻找在 pytorch 中使用它的最佳方法。数据集由几个.npz
文件组成,每个文件有 10k 个样本。我试图建立一个Dataset
类
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.files = os.listdir(self.path)
self.file_length =
for f in self.files:
# Load file in as a nmap
d = np.load(os.path.join(self.path, f), mmap_mode='r')
self.file_length[f] = len(d['y'])
def __len__(self):
raise NotImplementedException()
def __getitem__(self, idx):
# Find the file where idx belongs to
count = 0
f_key = ''
local_idx = 0
for k in self.file_length:
if count < idx < count + self.file_length[k]:
f_key = k
local_idx = idx - count
break
else:
count += self.file_length[k]
# Open file as numpy.memmap
d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
# Actually fetch the data
X = np.expand_dims(d['X'][local_idx], axis=1)
y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
return X, y
但实际提取样本时,需要 30 多秒。看起来整个 .npz
已打开,存储在 RAM 中并访问了正确的索引。
如何提高效率?
编辑
这似乎是对.npz
文件see post的误解,但是有没有更好的方法?
解决方案建议
正如@covariantmonkey 所建议的,lmdb 可能是一个不错的选择。目前,由于问题来自.npz
文件而不是memmap
,我通过将.npz
包文件拆分为几个.npy
文件来重构我的数据集。我现在可以使用与memmap
相同的逻辑,并且速度非常快(加载样本需要几毫秒)。
【问题讨论】:
【参考方案1】:单个.npz
文件有多大?一个月前我也遇到过类似的情况。各种forum 帖子,后来谷歌搜索我去了lmdb 路线。这就是我所做的
-
将大型数据集分成足够小的文件,以便我可以放入 gpu 中 — 每个文件本质上都是我的 minibatch。在这个阶段我没有优化加载时间只是内存。
用
key = filename
和data = np.savez_compressed(stff)
创建一个lmdb索引
lmdb
为您处理 mmap,并且加载速度非常快。
问候, 一个
PS:savez_compessed
需要一个字节对象,因此您可以执行类似的操作
output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb
【讨论】:
以上是关于如何在 pytorch 中处理大型数据集的主要内容,如果未能解决你的问题,请参考以下文章
小白学习PyTorch教程十基于大型电影评论数据集训练第一个LSTM模型
深度学习之Pytorch——如何使用张量处理文本数据集(语料库数据集)