pytorch backward error, one of the variables needed for gradient computation has been modified by an inplace operation
【Posted】2020-11-04 17:10:55 【Question】I am new to PyTorch and have been trying to implement a text-summarization network. When I call loss.backward() I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 1, 1, 400]], which is output 0 of UnsqueezeBackward0, is at version 98; expected version 97 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
This is a seq2seq model, and I think the problem is in this code snippet:
final_dists = torch.zeros((batch_size, dec_max_len, extended_vsize))  # holds the model outputs over the extended vocab
attn_dists = torch.zeros((batch_size, dec_max_len, enc_max_len))      # retains the attention weights over decoder steps
coverages = torch.zeros((batch_size, dec_max_len, enc_max_len))       # coverages are retained to compute the coverage loss
inp = self.emb_dropout(self.embedding(dec_batch[:, 0]))  # starting input: <SOS>, shape [batch_size]
# self.prev_coverage is the accumulated coverage
coverage = None  # initially None, but accumulates
with torch.autograd.set_detect_anomaly(True):
    for i in range(1, dec_max_len):
        # NOTE: the outputs, attn_dists, p_gens assignments start from i=1 (DON'T FORGET!)
        vocab_dists, hidden, attn_dists_tmp, p_gen, coverage = self.decoder(
            inp, hidden, enc_outputs, enc_lens, coverage)
        attn_dists[:, i, :] = attn_dists_tmp.squeeze(1)
        coverages[:, i, :] = coverage.squeeze(1)
        # vocab_dists:    [batch_size, 1, dec_vocab_size]  Note: normalized probabilities
        # hidden:         [1, batch_size, dec_hid_dim]
        # attn_dists_tmp: [batch_size, 1, enc_max_len]
        # p_gen:          [batch_size, 1]
        # coverage:       [batch_size, 1, enc_max_len]
        # ===================================================================
        # Compute the final dist in the pointer-generator network by extending the vocabulary
        vocab_dists_p = p_gen.unsqueeze(-1) * vocab_dists  # [batch_size, 1, dec_vocab_size]; keep vocab_dists for teacher forcing
        attn_dists_tmp = (1 - p_gen).unsqueeze(-1) * attn_dists_tmp  # [batch_size, 1, enc_max_len]; keep attn_dists for later use
        extra_zeros = torch.zeros((batch_size, 1, max_art_oovs)).to(self.device)
        vocab_dists_extended = torch.cat((vocab_dists_p, extra_zeros), dim=2)  # [batch_size, 1, extended_vsize]
        attn_dists_projected = torch.zeros((batch_size, 1, extended_vsize)).to(self.device)
        indices = enc_batch_extend_vocab.clone().unsqueeze(1)  # [batch_size, 1, enc_max_len]
        attn_dists_projected = attn_dists_projected.scatter(2, indices, attn_dists_tmp)
        # We need this, otherwise we would modify a leaf Variable in place
        # attn_dists_projected_clone = attn_dists_projected.clone()
        # attn_dists_projected_clone.scatter_(2, indices, attn_dists_tmp)  # projects the attention weights
        # attn_dists_projected.scatter_(2, indices, attn_dists_tmp)
        final_dists[:, i, :] = vocab_dists_extended.squeeze(1) + attn_dists_projected.squeeze(1)
        # ===================================================================
        # Teacher forcing: decide whether to feed the prediction or the decoder sequence label
        if random.random() < teacher_forcing_ratio:
            inp = self.emb_dropout(self.embedding(dec_batch[:, i]))
        else:
            inp = self.emb_dropout(self.embedding(vocab_dists.squeeze(1).argmax(1)))
If I remove the for loop and run just a single step that updates attn_dists[:,1,:] etc., computing a toy loss from the outputs returned by forward, it works fine. Does anyone know what is wrong here? As far as I can tell there is no in-place operation. Many thanks!
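For reference, the version-counter error in the traceback can be reproduced with a tiny standalone sketch (unrelated to the model above): sigmoid saves its output for the backward pass, so mutating that output in place invalidates the graph in exactly this way.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.sigmoid()   # sigmoid's backward needs its *output* y
y.add_(1.0)       # in-place edit bumps y's version counter

try:
    y.sum().backward()
except RuntimeError as e:
    # "one of the variables needed for gradient computation has been
    #  modified by an inplace operation ... is at version 1; expected version 0"
    print("caught:", type(e).__name__)
```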
【Discussion】:
【Answer 1】Looking at your code, the problem probably comes from the following lines:
attn_dists[:,i,:]=attn_dists_tmp.squeeze(1)
coverages[:,i,:]=coverage.squeeze(1)
You are performing in-place operations that conflict with the graph PyTorch builds for backpropagation. It should be solved by concatenating the new information at each loop iteration (though you may run out of memory quickly!):
attn_dists = torch.cat((attn_dists, attn_dists_tmp.squeeze(1)), dim=1)
coverages = torch.cat((coverages, coverage.squeeze(1)), dim=1)
You should also change their initialization accordingly, otherwise you will end up with a tensor twice the size you intended.
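A closely related pattern (not from the answer above, just a common PyTorch idiom) avoids both the in-place slice writes and the re-initialization issue: collect the per-step tensors in a Python list and concatenate once after the loop. The sizes and the softmax stand-in below are hypothetical placeholders for the question's decoder outputs.

```python
import torch

# Hypothetical sizes standing in for batch_size / dec_max_len / enc_max_len.
batch_size, dec_max_len, enc_max_len = 2, 4, 5
scores = torch.randn(batch_size, dec_max_len, enc_max_len, requires_grad=True)

attn_steps = []
for i in range(dec_max_len):
    # stand-in for the decoder's per-step attention, shape [batch_size, 1, enc_max_len]
    attn_steps.append(torch.softmax(scores[:, i:i + 1, :], dim=-1))

attn_dists = torch.cat(attn_steps, dim=1)  # [batch_size, dec_max_len, enc_max_len]
attn_dists.sum().backward()                # backward succeeds: no in-place writes
print(attn_dists.shape)
```

Appending to a list does not touch any tensor in place, so no version counter is bumped and autograd can replay every step.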
【Comments】:
Thanks for the reply. torch.cat would expand the dimension, which is not what I want. Also, attn_dists[:,i,:]=attn_dists_tmp.squeeze(1) assigns a slice from a different variable; I have seen seq2seq examples where an outputs tensor is created and outputs[i] is assigned the prediction inside the loop. I don't see a difference here. Or am I missing something? Many thanks.
I tried it, same problem, so I don't think it is the slice assignment. Thanks for the torch.cat suggestion though, the code will look cleaner.
I just noticed this line: attn_dists_tmp=(1-p_gen).unsqueeze(-1)*attn_dists_tmp. Have you tried changing it to attn_dists_tmp2=(1-p_gen).unsqueeze(-1)*attn_dists_tmp? You should look for in-place operations.
Yes, I also tried attn_dists_tmp.clone(), same error.
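For what it's worth, the out-of-place Tensor.scatter used in the snippet really does leave its input untouched (unlike the trailing-underscore scatter_), so that particular line does not bump any version counter. A quick check with toy shapes:

```python
import torch

base = torch.zeros(1, 1, 6)
idx = torch.tensor([[[2, 4]]])
src = torch.tensor([[[0.3, 0.7]]])

out = base.scatter(2, idx, src)       # out-of-place: returns a new tensor
assert base.sum().item() == 0.0       # original untouched, version counter intact
assert abs(out[0, 0, 2].item() - 0.3) < 1e-6
```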