pytorch 数据加载器和/或 __getitem__ 函数中的浅拷贝和深拷贝
Posted
技术标签:
【中文标题】pytorch 数据加载器和/或 __getitem__ 函数中的浅拷贝和深拷贝【英文标题】:shallow and deep copies in pytorch dataloader and/or __getitem__ function 【发布时间】:2022-01-17 21:42:22 【问题描述】:我在使用自定义 pytorch 数据加载器时遇到了问题,我认为这与 __getitem__()
函数中的浅拷贝和深拷贝有关。但是,有些行为我不明白。而且我不知道它是来自 pytorch 数据加载器类还是其他地方。
我根据自己的复杂用例创建了一个最小的工作示例。最初,我有一个保存为.hdf5
的数据集,我将其加载到__init__()
中。对于 NN,我希望将元素归一化为 1(我除以它们的总和)并分别返回总和。 :
# imports
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# create dataset with fixed seed
np.random.seed(1234)
data = np.random.rand(20, 4)
print(data)
# create custom dataset class
class TestDataset(Dataset):
""" Test dataset to illustrate bug in get_item """
def __init__(self, data_array, transform=None, apply_logit=True, with_noise=False):
"""
Args:
data_array (np.array): representing data loaded from hdf5 file or so
transform (None, callable or 'norm'): if data should be transformed
apply_logit (bool): if logit transform should be applied at the end
with_noise (bool): if noise should be applied in each call
"""
self.data = data_array
self.transform = transform
self.apply_logit = apply_logit
self.with_noise = with_noise
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
data = self.data[idx]
if self.with_noise:
data = add_noise(data)
data_sum = data.sum(axis=(-1), keepdims=True)
if self.transform:
if self.transform == 'norm':
data /= (data_sum + 1e-16) # this should be avoided
else:
data = self.transform(data)
if self.apply_logit:
data = logit_trafo(data)
sample = 'data': data, 'data_sum': data_sum.squeeze()
return sample
def get_dataloader(data_array, device, batch_size=2, apply_logit=True, with_noise=False, normed=False):
kwargs = 'num_workers': 2, 'pin_memory': True if device.type is 'cuda' else
dataset = TestDataset(data_array, transform='norm' if normed else None, apply_logit=apply_logit,
with_noise=with_noise)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)
def add_noise(input_tensor):
noise = np.random.rand(*input_tensor.shape)*1e-6
return input_tensor+noise
ALPHA = 1e-6
def logit(x):
return np.log(x / (1.0 - x))
def logit_trafo(x):
local_x = ALPHA + (1. - 2.*ALPHA) * x
return logit(local_x)
# with_noise=False will print just [1. 1.] after one epoch (due to the /= operation above)
# with_noise=True will remove this issue. Why?
mydata = get_dataloader(data, torch.device('cpu'), apply_logit=False, with_noise=False, normed=True)
with torch.no_grad():
for n in range(3):
print("epoch: ", n)
for i, elem in enumerate(mydata):
print('batch: ', i, #elem['data'].numpy(),
elem['data_sum'].numpy())
我得到以下输出:
[[0.19151945 0.62210877 0.43772774 0.78535858]
[0.77997581 0.27259261 0.27646426 0.80187218]
[0.95813935 0.87593263 0.35781727 0.50099513]
[0.68346294 0.71270203 0.37025075 0.56119619]
[0.50308317 0.01376845 0.77282662 0.88264119]
[0.36488598 0.61539618 0.07538124 0.36882401]
[0.9331401 0.65137814 0.39720258 0.78873014]
[0.31683612 0.56809865 0.86912739 0.43617342]
[0.80214764 0.14376682 0.70426097 0.70458131]
[0.21879211 0.92486763 0.44214076 0.90931596]
[0.05980922 0.18428708 0.04735528 0.67488094]
[0.59462478 0.53331016 0.04332406 0.56143308]
[0.32966845 0.50296683 0.11189432 0.60719371]
[0.56594464 0.00676406 0.61744171 0.91212289]
[0.79052413 0.99208147 0.95880176 0.79196414]
[0.28525096 0.62491671 0.4780938 0.19567518]
[0.38231745 0.05387369 0.45164841 0.98200474]
[0.1239427 0.1193809 0.73852306 0.58730363]
[0.47163253 0.10712682 0.22921857 0.89996519]
[0.41675354 0.53585166 0.00620852 0.30064171]]
epoch: 0
batch: 0 [2.03671454 2.13090485]
batch: 1 [2.69288438 2.3276119 ]
batch: 2 [2.17231943 1.42448741]
batch: 3 [2.77045097 2.19023559]
batch: 4 [2.35475675 2.49511645]
batch: 5 [0.96633253 1.73269209]
batch: 6 [1.5517233 2.1022733]
batch: 7 [3.5333715 1.58393664]
batch: 8 [1.86984429 1.56915029]
batch: 9 [1.70794311 1.25945542]
epoch: 1
batch: 0 [1. 1.]
batch: 1 [1. 1.]
batch: 2 [1. 1.]
batch: 3 [1. 1.]
batch: 4 [1. 1.]
batch: 5 [1. 1.]
batch: 6 [1. 1.]
batch: 7 [1. 1.]
batch: 8 [1. 1.]
batch: 9 [1. 1.]
epoch: 2
batch: 0 [1. 1.]
batch: 1 [1. 1.]
batch: 2 [1. 1.]
batch: 3 [1. 1.]
batch: 4 [1. 1.]
batch: 5 [1. 1.]
batch: 6 [1. 1.]
batch: 7 [1. 1.]
batch: 8 [1. 1.]
batch: 9 [1. 1.]
在第一个 epoch 之后,应该给出每个输入向量之和的条目返回 1。根据我的理解,原因是 __getitem()__
内部的 /=
操作覆盖了原始数组(因为它只是一个浅拷贝)。但是,当我使用with_noise=True
创建数据加载器时,输出变为
epoch: 0
batch: 0 [2.03671714 2.13090728]
batch: 1 [2.69288618 2.32761437]
batch: 2 [2.17232151 1.42449024]
batch: 3 [2.7704527 2.19023717]
batch: 4 [2.35475926 2.49511859]
batch: 5 [0.96633553 1.73269352]
batch: 6 [1.55172434 2.10227475]
batch: 7 [3.53337356 1.58393908]
batch: 8 [1.86984558 1.56915276]
batch: 9 [1.70794503 1.25945833]
epoch: 1
batch: 0 [2.03671729 2.13090765]
batch: 1 [2.69288721 2.32761405]
batch: 2 [2.17232208 1.42449008]
batch: 3 [2.77045253 2.19023718]
batch: 4 [2.35475815 2.4951189 ]
batch: 5 [0.96633595 1.73269401]
batch: 6 [1.55172476 2.10227547]
batch: 7 [3.53337382 1.58393882]
batch: 8 [1.86984584 1.56915165]
batch: 9 [1.70794547 1.25945795]
epoch: 2
batch: 0 [2.03671533 2.13090593]
batch: 1 [2.69288633 2.32761373]
batch: 2 [2.17232158 1.42448975]
batch: 3 [2.77045371 2.19023796]
batch: 4 [2.3547586 2.49511857]
batch: 5 [0.96633348 1.73269476]
batch: 6 [1.55172544 2.10227616]
batch: 7 [3.53337367 1.58393892]
batch: 8 [1.86984568 1.56915256]
batch: 9 [1.70794379 1.25945825]
如果我添加的噪声乘以0.
,也是如此。
这是为什么呢?怎么突然变成深拷贝了?
【问题讨论】:
能否请您追查问题并制作一个实际的minimal reproducible example?x /= y
是一个 in-place 运算符:它修改它所操作的对象x
,而不是创建一个新对象,除非该对象是不可变的。 x = x / y
不就位:它创建一个新对象并将其绑定到名称x
。这有帮助吗?
上面的代码是完整的,它重现了这个问题。我知道最初的错误来自就地操作 x /= y (我也写过)。我的问题是为什么当我调用 add_noise() 函数时它会突然改变(即使我在那里添加 0)。
sn-p 代码过多,如果您花时间将其删减,则更短的代码会重现该问题。
input_tensor+noise
创建一个新数组。 input_tensor+=noise
不会。然后,您必须在单独的行上返回 input_tensor
。我在我的例子中使用了/
,但原理是一样的。
【参考方案1】:
谢谢你,疯狂的物理学家!我不得不阅读它和代码几次才能看到问题:
如果不调用add_noise()
,data /= (data_sum + 1e-16)
行就位,会更改原始输入数组。因此,对它的每次后续调用都会返回已经标准化的数据。对add_noise()
的调用会按照其编码方式创建一个新数组。然后就地操作仅更改新数组并且不触及原始数组(这是我错过的步骤)。因此后续调用会返回原始的、未标准化的数组。
【讨论】:
以上是关于pytorch 数据加载器和/或 __getitem__ 函数中的浅拷贝和深拷贝的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch学习6《PyTorch深度学习实践》——加载数据集(Dataset and DataLoader)