pytorch“试图第二次向后遍历图形”字符级RNN错误

Posted

技术标签:

【中文标题】pytorch“试图第二次向后遍历图形”字符级RNN错误【英文标题】:pytorch "trying to backward through the graph a second time" error with chracter level RNN 【发布时间】:2020-11-07 17:49:28 【问题描述】:

我正在使用 pytorch 训练字符级 GRU,同时将文本分成一定块长度的批次。 这是训练循环:

for e in range(self.epochs):
  self.model.train()
  h = self.get_init_state(self.batch_size)
  
  for batch_num in range(self.num_batch_runs):
    batch = self.generate_batch(batch_num).to(device)
    
    inp_batch = batch[:-1,:]
    tar_batch = batch[1:,:]
    
    
    self.model.zero_grad()
    loss = 0

    for i in range(inp_batch.shape[0]):
      out, h = self.model(inp_batch[i:i+1,:],h)

      loss += loss_fn(out[0],tar_batch[i].view(-1))
      
    
    loss.backward()

    nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)

    optimizer.step()
    

    if not (batch_num % 5):
      print("epoch: , loss: ".format(e,loss.data.item()/inp_batch.shape[0]))

不过,我在第一批之后收到此错误:

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

提前谢谢..

【问题讨论】:

这能回答你的问题吗? Pytorch - RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed 【参考方案1】:

我自己找到了答案,GRU的隐藏状态还是附着在上一次batch运行的,所以必须使用分离

h.detach_()

【讨论】:

以上是关于pytorch“试图第二次向后遍历图形”字符级RNN错误的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch - RuntimeError:尝试第二次向后退,但缓冲区已被释放

PyTorch-14 使用字符级RNN分类名字(姓名)

基于pytorch的LSTM进行字符级文本生成实战

Pytorch系列教程-使用字符级RNN生成姓名

PyTorch-15 使用字符级RNN生成名字(姓名)

pytorch 入门级项目的感悟