RuntimeError: each element in list of batch should be of equal size

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了RuntimeError: each element in list of batch should be of equal size相关的知识,希望对你有一定的参考价值。

RuntimeError: each element in list of batch should be of equal size

1. 示例代码

"""
完成数据集的准备
"""
from torch.utils.data import DataLoader, Dataset
import os
import re

# 分词
def tokenlize(content):
    content = re.sub('<.*?>', ' ', content, flags=re.S)
    filters = ['!', '"', '#', '$', '%', '&', '\\(', '\\)', '\\*', '\\+', ',', '-', '\\.', '/', ':', ';', '<', '=', '>', '\\?',
               '@', '\\[', '\\\\', '\\]', '^', '_', '`', '\\{', '\\|', '\\}', '~', '\\t', '\\n', '\\x97', '\\x96', '”', '“', ]

    content = re.sub('|'.join(filters), ' ', content)
    tokens = [i.strip().lower() for i in content.split()]

    return tokens


# 准备dataset
class ImdbDataset(Dataset):
    def __init__(self, train=True):
        self.train_data_path = r'E:\\Python资料\\视频\\Py5.0\\00.8-12课件资料V5.0\\阶段9-人工智能NLP项目\\第四天\\代码\\data\\aclImdb_v1\\aclImdb\\train'
        self.test_data_path = r'E:\\Python资料\\视频\\Py5.0\\00.8-12课件资料V5.0\\阶段9-人工智能NLP项目\\第四天\\代码\\data\\aclImdb_v1\\aclImdb\\test'
        data_path = self.train_data_path if train else self.test_data_path

        # 把所有的文件名放入列表
        temp_data_path = [os.path.join(data_path, 'pos'), os.path.join(data_path, 'neg')]
        self.total_file_path = []  # 所有评论文件的path
        for path in temp_data_path:
            file_name_list = os.listdir(path)
            file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith('.txt')]
            self.total_file_path.extend(file_path_list)

    def __getitem__(self, idx):
        file_path = self.total_file_path[idx]
        # 获取了label
        label_str = file_path.split('\\\\')[-2]
        label = 0 if label_str == 'neg' else 1
        # 获取内容
        # 分词
        tokens = tokenlize(open(file_path).read())
        return tokens, label

    def __len__(self):
        return len(self.total_file_path)


# 获取数据集加载器
def get_dataloader(train=True):
    imdb_dataset = ImdbDataset(train)
    print(imdb_dataset[1])
    data_loader = DataLoader(imdb_dataset, batch_size=2, shuffle=True)
    return data_loader


# 观察输出结果
if __name__ == '__main__':
    for idx, (input, target) in enumerate(get_dataloader()):
        print('idx', idx)
        print('input', input)
        print('target', target)
        break

2. 运行结果

3. 报错原因

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

如果把batch_size=2改为batch_size=1时就不再报错了,运行结果如下:

4. batch_size=2

但是,如果想让batch_size=2时,那该如何解决呢?

解决方法:

出现问题的原因在于Dataloader中的参数collate_fn

collate_fn的默认值为torch自定义的default_collate, collate_fn的作用就是对每个batch进行处理,而默认的default_collate处理出错。

解决思路:

  1. 考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误.
  2. 考虑自定义一个collate_fn,观察结果.

这里使用方式2,自定义一个collate_fn,然后观察结果:

def collate_fn(batch):
    """
    对batch数据进行处理
    :param batch: [一个getitem的结果,getitem的结果,getitem的结果]
    :return: 元组
    """
    reviews,labels = zip(*batch)
    reviews = torch.LongTensor([config.ws.transform(i,max_len=config.max_len) for i in reviews])
    labels = torch.LongTensor(labels)

    return reviews, labels

collate_fn第二种定义方式:

import config

def collate_fn(batch):
    """
    对batch数据进行处理
    :param batch: [一个getitem的结果,getitem的结果,getitem的结果]
    :return: 元组
    """
    reviews,labels = zip(*batch)
    reviews = torch.LongTensor([config.ws.transform(i,max_len=config.max_len) for i in reviews])
    labels = torch.LongTensor(labels)

    return reviews,labels

5. 分析原因

根据报错信息可以查找错误来源在collate.py源码,错误就出现在default_collate()函数中。百度发现此源码的defaul_collate函数是DataLoader类默认的处理batch的方法,如果在定义DataLoader时没有使用collate_fn参数指定函数,就会默认调用以下源码中的方法。如果你出现了上述报错,应该就是此函数中出现了倒数第四行的错误

源码

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \\
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

此函数功能就是传入一个batch数据元组,元组中是每个数据是你定义的dataset类中__getitem__()方法返回的内容,元组长度就是你的batch_size设置的大小。但是DataLoader类中最终返回的可迭代对象的一个字段是将batch_size大小样本的相应字段拼接到一起得到的。

因此默认调用此方法时,第一次会进入倒数第二行语句return [default_collate(samples) for samples in transposed]将batch元组通过zip函数生成可迭代对象。然后通过迭代取出相同字段递归重新传入default_collate()函数中,此时取出第一个字段判断数据类型在以上所列类型中,则可正确返回dateset内容。

如果batch数据是按以上顺序进行处理,则不会出现以上错误。如果进行第二次递归之后元素的数据还不在所列数据类型中,则依然会进入下一次也就是第三次递归,此时就算能正常返回数据也不符合我们要求,而且报错一般就是出现在第三次递归及之后。因此想要解决此错误,需要仔细检查自己定义dataset类返回字段的数据类型。也可以在defaule_collate()方法中输出处理前后batch内容,查看函数具体处理流程,以帮助自己查找返回字段数据类型的错误。

友情提示: 不要在源码文件中更改defaule_collate()方法,可以把此代码copy出来,定义一个自己的collate_fn()函数并在实例化DataLoader类时指定自己定义的collate_fn函数。

6. 完整代码

"""
完成数据集的准备
"""
from torch.utils.data import DataLoader, Dataset
import os
import re
import torch

# 分词
def tokenlize(content):
    content = re.sub('<.*?>', ' ', content, flags=re.S)
    # filters = ['!', '"', '#', '$', '%', '&', '\\(', '\\)', '\\*', '\\+', ',', '-', '\\.', '/', ':', ';', '<', '=', '>', '\\?',
    #            '@', '\\[', '\\\\', '\\]', '^', '_', '`', '\\{', '\\|', '\\}', '~', '\\t', '\\n', '\\x97', '\\x96', '”', '“', ]
    filters = ['\\.', '\\t', '\\n', '\\x97', '\\x96', '#', '$', '%', '&']
    content = re.sub('|'.join(filters), ' ', content)
    tokens = [i.strip().lower() for i in content.split()]
    return tokens


# 准备dataset
class ImdbDataset(Dataset):
    def __init__(self, train=True):
        self.train_data_path = r'.\\aclImdb\\train'
        self.test_data_path = r'.\\aclImdb\\test'
        data_path = self.train_data_path if train else self.test_data_path

        # 把所有的文件名放入列表
        temp_data_path = [os.path.join(data_path, 'pos'), os.path.join(data_path, 'neg')]
        self.total_file_path = []  # 所有评论文件的path
        for path in temp_data_path:
            file_name_list = os.listdir(path)
            file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith('.txt')]
            self.total_file_path.extend(file_path_list)

    def __getitem__(self, idx):
        file_path = self.total_file_path[idx]
        # 获取label
        label_str = file_path.split('\\\\')[-2]
        label = 0 if label_str == 'neg' else 1
        # 获取内容
        # 分词
        tokens = tokenlize(open(file_path).read().strip())  # # 直接按照空格进行分词
        return label, tokens

    def __len__(self):
        return len(self.total_file_path)

def collate_fn(batch):
    #  batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts


# 获取数据集加载器
def get_dataloader(train=True):
    imdb_dataset = ImdbDataset(train)
    data_loader = DataLoader(imdb_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    return data_loader

# 观察输出结果
if __name__ == '__main__':
    for idx, (input, target) in enumerate(get_dataloader()):
        print('idx', idx)
        print('input', input)
        print('target', target)
        break

祝大家早日解决bug,跑通模型!

加油!

感谢!

努力!

以上是关于RuntimeError: each element in list of batch should be of equal size的主要内容,如果未能解决你的问题,请参考以下文章

每天一道LeetCode--169.Majority Elemen

10.17 elemen.js

elemen-table表格数据转换-formatter属性

html [css:center-a-position-absolute-element] #css

vue+elemen把时间作为参数搜索数据注意一点

vue + elemen可远程搜索select选择器的封装(思路及源码分享)