RuntimeError: each element in list of batch should be of equal size
Posted ZSYL
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了RuntimeError: each element in list of batch should be of equal size相关的知识,希望对你有一定的参考价值。
RuntimeError: each element in list of batch should be of equal size
1. 示例代码
"""
完成数据集的准备
"""
from torch.utils.data import DataLoader, Dataset
import os
import re
# 分词
def tokenlize(content):
content = re.sub('<.*?>', ' ', content, flags=re.S)
filters = ['!', '"', '#', '$', '%', '&', '\\(', '\\)', '\\*', '\\+', ',', '-', '\\.', '/', ':', ';', '<', '=', '>', '\\?',
'@', '\\[', '\\\\', '\\]', '^', '_', '`', '\\{', '\\|', '\\}', '~', '\\t', '\\n', '\\x97', '\\x96', '”', '“', ]
content = re.sub('|'.join(filters), ' ', content)
tokens = [i.strip().lower() for i in content.split()]
return tokens
# 准备dataset
class ImdbDataset(Dataset):
def __init__(self, train=True):
self.train_data_path = r'E:\\Python资料\\视频\\Py5.0\\00.8-12课件资料V5.0\\阶段9-人工智能NLP项目\\第四天\\代码\\data\\aclImdb_v1\\aclImdb\\train'
self.test_data_path = r'E:\\Python资料\\视频\\Py5.0\\00.8-12课件资料V5.0\\阶段9-人工智能NLP项目\\第四天\\代码\\data\\aclImdb_v1\\aclImdb\\test'
data_path = self.train_data_path if train else self.test_data_path
# 把所有的文件名放入列表
temp_data_path = [os.path.join(data_path, 'pos'), os.path.join(data_path, 'neg')]
self.total_file_path = [] # 所有评论文件的path
for path in temp_data_path:
file_name_list = os.listdir(path)
file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith('.txt')]
self.total_file_path.extend(file_path_list)
def __getitem__(self, idx):
file_path = self.total_file_path[idx]
# 获取了label
label_str = file_path.split('\\\\')[-2]
label = 0 if label_str == 'neg' else 1
# 获取内容
# 分词
tokens = tokenlize(open(file_path).read())
return tokens, label
def __len__(self):
return len(self.total_file_path)
# 获取数据集加载器
def get_dataloader(train=True):
imdb_dataset = ImdbDataset(train)
print(imdb_dataset[1])
data_loader = DataLoader(imdb_dataset, batch_size=2, shuffle=True)
return data_loader
# 观察输出结果
if __name__ == '__main__':
for idx, (input, target) in enumerate(get_dataloader()):
print('idx', idx)
print('input', input)
print('target', target)
break
2. 运行结果
3. 报错原因
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
如果把batch_size=2
改为batch_size=1
时就不再报错了,运行结果如下:
4. batch_size=2
但是,如果想让batch_size=2
时,那该如何解决呢?
解决方法:
出现问题的原因在于Dataloader中的参数collate_fn
collate_fn
的默认值为torch自定义的default_collate
, collate_fn
的作用就是对每个batch进行处理,而默认的default_collate
处理出错。
解决思路:
- 考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误.
- 考虑自定义一个collate_fn,观察结果.
这里使用方式2,自定义一个collate_fn
,然后观察结果:
def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
reviews = torch.LongTensor([config.ws.transform(i,max_len=config.max_len) for i in reviews])
labels = torch.LongTensor(labels)
return reviews, labels
collate_fn
第二种定义方式:
import config
def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
reviews = torch.LongTensor([config.ws.transform(i,max_len=config.max_len) for i in reviews])
labels = torch.LongTensor(labels)
return reviews,labels
5. 分析原因
根据报错信息可以查找错误来源在collate.py
源码,错误就出现在default_collate()
函数中。百度发现此源码的defaul_collate
函数是DataLoader类默认的处理batch的方法,如果在定义DataLoader时没有使用collate_fn
参数指定函数,就会默认调用以下源码中的方法。如果你出现了上述报错,应该就是此函数中出现了倒数第四行的错误
源码:
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \\
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int_classes):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
此函数功能就是传入一个batch数据元组,元组中是每个数据是你定义的dataset类中__getitem__()
方法返回的内容,元组长度就是你的batch_size设置的大小。但是DataLoader类中最终返回的可迭代对象的一个字段是将batch_size大小样本的相应字段拼接到一起得到的。
因此默认调用此方法时,第一次会进入倒数第二行语句return [default_collate(samples) for samples in transposed]
将batch元组通过zip函数生成可迭代对象。然后通过迭代取出相同字段递归重新传入default_collate()
函数中,此时取出第一个字段判断数据类型在以上所列类型中,则可正确返回dateset内容。
如果batch数据是按以上顺序进行处理,则不会出现以上错误。如果进行第二次递归之后元素的数据还不在所列数据类型中,则依然会进入下一次也就是第三次递归,此时就算能正常返回数据也不符合我们要求,而且报错一般就是出现在第三次递归及之后。因此想要解决此错误,需要仔细检查自己定义dataset类返回字段的数据类型。也可以在defaule_collate()
方法中输出处理前后batch内容,查看函数具体处理流程,以帮助自己查找返回字段数据类型的错误。
友情提示: 不要在源码文件中更改
defaule_collate()
方法,可以把此代码copy出来,定义一个自己的collate_fn()
函数并在实例化DataLoader类时指定自己定义的collate_fn
函数。
6. 完整代码
"""
完成数据集的准备
"""
from torch.utils.data import DataLoader, Dataset
import os
import re
import torch
# 分词
def tokenlize(content):
content = re.sub('<.*?>', ' ', content, flags=re.S)
# filters = ['!', '"', '#', '$', '%', '&', '\\(', '\\)', '\\*', '\\+', ',', '-', '\\.', '/', ':', ';', '<', '=', '>', '\\?',
# '@', '\\[', '\\\\', '\\]', '^', '_', '`', '\\{', '\\|', '\\}', '~', '\\t', '\\n', '\\x97', '\\x96', '”', '“', ]
filters = ['\\.', '\\t', '\\n', '\\x97', '\\x96', '#', '$', '%', '&']
content = re.sub('|'.join(filters), ' ', content)
tokens = [i.strip().lower() for i in content.split()]
return tokens
# 准备dataset
class ImdbDataset(Dataset):
def __init__(self, train=True):
self.train_data_path = r'.\\aclImdb\\train'
self.test_data_path = r'.\\aclImdb\\test'
data_path = self.train_data_path if train else self.test_data_path
# 把所有的文件名放入列表
temp_data_path = [os.path.join(data_path, 'pos'), os.path.join(data_path, 'neg')]
self.total_file_path = [] # 所有评论文件的path
for path in temp_data_path:
file_name_list = os.listdir(path)
file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith('.txt')]
self.total_file_path.extend(file_path_list)
def __getitem__(self, idx):
file_path = self.total_file_path[idx]
# 获取label
label_str = file_path.split('\\\\')[-2]
label = 0 if label_str == 'neg' else 1
# 获取内容
# 分词
tokens = tokenlize(open(file_path).read().strip()) # # 直接按照空格进行分词
return label, tokens
def __len__(self):
return len(self.total_file_path)
def collate_fn(batch):
# batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果
batch = list(zip(*batch))
labels = torch.tensor(batch[0], dtype=torch.int32)
texts = batch[1]
del batch
return labels, texts
# 获取数据集加载器
def get_dataloader(train=True):
imdb_dataset = ImdbDataset(train)
data_loader = DataLoader(imdb_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
return data_loader
# 观察输出结果
if __name__ == '__main__':
for idx, (input, target) in enumerate(get_dataloader()):
print('idx', idx)
print('input', input)
print('target', target)
break
祝大家早日解决bug,跑通模型!
加油!
感谢!
努力!
以上是关于RuntimeError: each element in list of batch should be of equal size的主要内容,如果未能解决你的问题,请参考以下文章
每天一道LeetCode--169.Majority Elemen
elemen-table表格数据转换-formatter属性