pytorch collate_fn 拒绝样本并产生另一个
Posted
技术标签:
【中文标题】pytorch collate_fn 拒绝样本并产生另一个【英文标题】:pytorch collate_fn reject sample and yield another 【发布时间】:2020-01-08 22:01:53 【问题描述】:我已经建立了一个数据集,我正在对正在加载的图像进行各种检查。然后我将此 DataSet 传递给 DataLoader。
在我的 DataSet 类中,如果图片未通过我的检查,我会将样本返回为 None,并且我有一个自定义 collate_fn 函数,该函数从检索到的批次中删除所有 None 并返回剩余的有效样本。
但是,此时返回的批次可能具有不同的大小。有没有办法告诉 collate_fn 保持采购数据,直到批量大小达到一定长度?
class DataSet():
def __init__(self, example):
# initialise dataset
# load csv file and image directory
self.example = example
def __getitem__(self,idx):
# load one sample
# if image is too dark return None
# else
# return one image and its equivalent label
dataset = Dataset(csv_file='../', image_dir='../../')
dataloader = DataLoader(dataset , batch_size=4,
shuffle=True, num_workers=1, collate_fn = my_collate )
def my_collate(batch): # batch size 4 [tensor image, tensor label,,,] could return something like G = [None, ,,]
batch = list(filter (lambda x:x is not None, batch)) # this gets rid of nones in batch. For example above it would result to G = [,,]
# I want len(G) = 4
# so how to sample another dataset entry?
return torch.utils.data.dataloader.default_collate(batch)
【问题讨论】:
【参考方案1】:有2个hack可以用来解决问题,选择一种方式:
使用原始批次样本快速选项:
def my_collate(batch):
len_batch = len(batch) # original batch length
batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
diff = len_batch - len(batch)
for i in range(diff):
batch = batch + batch[:diff]
return torch.utils.data.dataloader.default_collate(batch)
否则只是从数据集中随机加载另一个样本更好的选择:
def my_collate(batch):
len_batch = len(batch) # original batch length
batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
if len_batch > len(batch): # source all the required samples from the original dataset at random
diff = len_batch - len(batch)
for i in range(diff):
batch.append(dataset[np.random.randint(0, len(dataset))])
return torch.utils.data.dataloader.default_collate(batch)
【讨论】:
您将如何构造数据加载器 collate_fn 参数以使数据集在范围内? 感谢代码!我认为“更好的选择”也应该支持新样本也可能是无。所以我猜应该有类似while循环的东西。【参考方案2】:对于任何希望即时拒绝训练示例的人,无需使用技巧来解决数据加载器的 collate_fn 中的问题,只需使用 IterableDataset 并编写 __iter__ 和 __next__ 函数,如下所示
def __iter__(self):
return self
def __next__(self):
# load the next non-None example
【讨论】:
【参考方案3】:这对我有用,因为有时甚至那些随机值都是无。
def my_collate(batch):
len_batch = len(batch)
batch = list(filter(lambda x: x is not None, batch))
if len_batch > len(batch):
db_len = len(dataset)
diff = len_batch - len(batch)
while diff != 0:
a = dataset[np.random.randint(0, db_len)]
if a is None:
continue
batch.append(a)
diff -= 1
return torch.utils.data.dataloader.default_collate(batch)
【讨论】:
【参考方案4】:感谢 Brian Formento 提出解决问题的方法和建议。如前所述,用新示例替换坏示例的最佳选择有两个问题:
-
新采样的示例也可能已损坏;
数据集不在范围内。
这里有一个解决方案 - 问题 1 通过递归调用解决,问题 2 通过创建 collate 函数的部分函数并固定数据集。
import random
import torch
def collate_fn_replace_corrupted(batch, dataset):
"""Collate function that allows to replace corrupted examples in the
dataloader. It expect that the dataloader returns 'None' when that occurs.
The 'None's in the batch are replaced with another examples sampled randomly.
Args:
batch (torch.Tensor): batch from the DataLoader.
dataset (torch.utils.data.Dataset): dataset which the DataLoader is loading.
Specify it with functools.partial and pass the resulting partial function that only
requires 'batch' argument to DataLoader's 'collate_fn' option.
Returns:
torch.Tensor: batch with new examples instead of corrupted ones.
"""
# Idea from https://***.com/a/57882783
original_batch_len = len(batch)
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
filtered_batch_len = len(batch)
# Num of corrupted examples
diff = original_batch_len - filtered_batch_len
if diff > 0:
# Replace corrupted examples with another examples randomly
batch.extend([dataset[random.randint(0, len(dataset))] for _ in range(diff)])
# Recursive call to replace the replacements if they are corrupted
return collate_fn_replace_corrupted(batch, dataset)
# Finally, when the whole batch is fine, return it
return torch.utils.data.dataloader.default_collate(batch)
但是,您不能将其直接传递给 DataLoader
,因为 collate 函数应该只有一个参数 - batch
。为此,我们使用指定的数据集创建一个偏函数,并将偏函数传递给DataLoader
。
import functools
from torch.utils.data import DataLoader
collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
return DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn)
【讨论】:
【参考方案5】:对于快速选项,它有问题。以下是固定版本。
def my_collate(batch):
len_batch = len(batch) # original batch length
batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
diff = len_batch - len(batch)
batch = batch + batch[:diff] # assume diff < len(batch)
return torch.utils.data.dataloader.default_collate(batch)
【讨论】:
也许您想在您所做的事情中添加一些解释? 似乎原始答案中不需要for
循环。以上是关于pytorch collate_fn 拒绝样本并产生另一个的主要内容,如果未能解决你的问题,请参考以下文章