PyTorch: Data Loading 2 - DataLoader

Posted by -柚子皮-

-柚子皮-

Using the DataLoader in NLP

Parameters of torch.utils.data.DataLoader:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch.
  • pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

Return value

      The return value is an object that implements __iter__, so you can iterate over it with a for loop, or convert it to an iterator and fetch the first batch to inspect it.

Each item yielded during for-loop iteration has size (batch_size, *).
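For example, a minimal sketch using a toy TensorDataset (names and sizes here are illustrative, not from the original post):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples with 3 features each.
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
loader = DataLoader(dataset, batch_size=4)

# Option 1: iterate with a for loop.
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 3]) torch.Size([4]) for full batches

# Option 2: convert to an iterator and inspect just the first batch.
first_x, first_y = next(iter(loader))
print(first_x.shape)  # torch.Size([4, 3])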

Common operations

Usage example

self.data_loader = torch.utils.data.DataLoader(
            dataset=self.dataset, collate_fn=self.collate_fn,
            batch_size=batch_size, shuffle=if_shuffle, num_workers=args.num_workers)

Number of batches

batch_num = len(train_dataloader)
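Note that len(train_dataloader) counts batches, not samples, and depends on drop_last; a quick sketch with a hypothetical 10-sample dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10, 3))  # 10 samples
print(len(DataLoader(dataset, batch_size=4)))                   # 3 == ceil(10 / 4)
print(len(DataLoader(dataset, batch_size=4, drop_last=True)))   # 2 == floor(10 / 4)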

Getting the first item from the dataset

train_data_loader.dataset.__getitem__(0)  # equivalent to train_data_loader.dataset[0]

Getting one item out of a batch from the DataLoader

def get_one_data(item_dict, i):
    return {k: v[i] for k, v in item_dict.items()}

print(get_one_data(next(iter(train_data_loader)), 1))

Or:

for index, item_dict in enumerate(train_data_loader):
    print(get_one_data(item_dict, 1))
    exit()

 

Custom DataLoader

The DataLoader's processing logic is: first fetch individual samples through the Dataset's __getitem__ method, assemble batch_size of them into a batch, and then apply the function specified by collate_fn to that batch (for example, computing the actual lengths within each batch, padding, moving tensors to CUDA, and so on). A minimal sketch of this pipeline follows.
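In this sketch the class and function names are illustrative, not from the original post:

import torch
from torch.utils.data import Dataset, DataLoader

class SentenceDataset(Dataset):
    def __init__(self, sentences):
        self.sentences = sentences  # list of variable-length token-id lists

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        # Step 1: return a single sample; the DataLoader calls this batch_size times.
        return self.sentences[idx]

def collate(batch):
    # Step 2: `batch` is the list of __getitem__ results; record lengths and pad.
    lengths = torch.LongTensor([len(s) for s in batch])
    padded = torch.zeros(len(batch), int(lengths.max())).long()
    for i, s in enumerate(batch):
        padded[i, :len(s)] = torch.LongTensor(s)
    return padded, lengths

loader = DataLoader(SentenceDataset([[1, 2], [3, 4, 5], [6]]),
                    batch_size=2, collate_fn=collate)
for padded, lengths in loader:
    print(padded, lengths)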

Custom collate_fn

Because the DataLoader has a batch_size parameter, we can customize how samples are collected with collate_fn=myfunction: batch_size samples are first drawn through the Dataset's __getitem__ described above, then handed as a single list to the function specified by collate_fn.

Example 1: unpacking inside collate_fn

def myfunction(data):
    # `data` is a list of batch_size samples, each an (A, B, path, hop) tuple.
    A, B, path, hop = zip(*data)
    print('A:', A, ' B:', B, ' path:', path, ' hop:', hop)
    raise Exception('utils collate_fun 147')  # deliberate stop for debugging
    return A, B, path, hop

 

for index, item in enumerate(dataloaders['train']):
    A, B, path, hop = item

Note: you need to iterate over the dataloader from the outside with a for loop (then hit a breakpoint or exit()); otherwise collate_fn is never actually executed and nothing gets printed.
Example 2: in NLP tasks, padding is often done in the function specified by collate_fn, padding sentences of different lengths within the same batch to the same length.

def myfunction(data):
    src, tgt, original_src, original_tgt = zip(*data)

    # Pad the source sequences; note the source is written in reversed order
    # (s[end-1::-1]), a common trick in seq2seq models.
    src_len = [len(s) for s in src]
    src_pad = torch.zeros(len(src), max(src_len)).long()
    for i, s in enumerate(src):
        end = src_len[i]
        src_pad[i, :end] = torch.LongTensor(s[end-1::-1])

    # Pad the target sequences in their original order.
    tgt_len = [len(s) for s in tgt]
    tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
    for i, s in enumerate(tgt):
        end = tgt_len[i]
        tgt_pad[i, :end] = torch.LongTensor(s)[:end]

    return src_pad, tgt_pad, \
           torch.LongTensor(src_len), torch.LongTensor(tgt_len), \
           original_src, original_tgt
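To plug this into a DataLoader (a hedged sketch; `dataset` is assumed to yield (src, tgt, original_src, original_tgt) tuples, matching what zip(*data) above expects):

train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, shuffle=True, collate_fn=myfunction)
src_pad, tgt_pad, src_len, tgt_len, original_src, original_tgt = next(iter(train_loader))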

 

Common issues

[Why does the PyTorch DataLoader behave differently on numpy arrays and lists?]

1 Import issue

When using torch.utils.data.DataLoader, PyCharm cannot jump directly into its source code by clicking.

2 num_workers set too large

If num_workers is set too large and the machine does not have enough resources, the program can crash with: Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)
[Running a TensorFlow program bug: Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)]
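One pragmatic mitigation (a sketch, not an official fix) is to cap num_workers at the machine's CPU count, and drop it to 0 when debugging to rule out worker processes entirely:

import os
import torch

# `args.num_workers` as in the usage example above; cap it at the CPU count.
num_workers = min(args.num_workers, os.cpu_count() or 1)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, num_workers=num_workers)  # set num_workers=0 to debug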

[Python module imports and attributes: import]

[https://github.com/pytorch/pytorch/issues/41794]

from: -柚子皮-

ref: [https://www.jianshu.com/p/8ea7fba72673]

 
