如何使用 Pytorch 中的截断反向传播(闪电)在很长的序列上运行 LSTM?

Posted

技术标签:

【中文标题】如何使用 Pytorch 中的截断反向传播(闪电)在很长的序列上运行 LSTM?【英文标题】:How to run LSTM on very long sequence using Truncated Backpropagation in Pytorch (lightning)? 【发布时间】:2021-06-28 08:05:30 【问题描述】:

我有一个很长的时间序列,我想将其输入 LSTM 以进行每帧分类。

我的数据是按帧标记的,我知道一些罕见的事件发生后会严重影响分类。

因此,我必须输入整个序列才能获得有意义的预测。

众所周知,将非常长的序列输入 LSTM 是次优的,因为梯度会像正常的 RNN 一样消失或爆炸。


我想使用一种简单的技术将序列切割成更短(比如 100 长)的序列,并在每个序列上运行 LSTM,然后将最终的 LSTM 隐藏状态和单元状态作为起始隐藏状态和单元状态传递给下一次向前传球。

Here 是我发现的一个例子。在那里它被称为“通过时间截断的反向传播”。我无法为我做同样的工作。


我对 Pytorch 闪电的尝试(去掉了不相关的部分):

def __init__(self, config, n_classes, datamodule):
    ...
    self._criterion = nn.CrossEntropyLoss(
        reduction='mean',
    )

    num_layers = 1
    hidden_size = 50
    batch_size=1

    self._lstm1 = nn.LSTM(input_size=len(self._in_features), hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self._log_probs = nn.Linear(hidden_size, self._n_predicted_classes)
    self._last_h_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)
    self._last_c_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)

def training_step(self, batch, batch_index):
    orig_batch, label_batch = batch
    n_labels_in_batch = np.prod(label_batch.shape)
    lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch, (self._last_h_n, self._last_c_n))
    log_probs = self._log_probs(lstm_out)
    loss = self._criterion(log_probs.view(n_labels_in_batch, -1), label_batch.view(n_labels_in_batch))

    return loss

运行此代码会出现以下错误:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

如果我添加也会发生同样的情况

def on_after_backward(self) -> None:
    self._last_h_n.detach()
    self._last_c_n.detach()

如果我使用该错误不会发生

lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch,)

但这显然没有用,因为当前帧批次的输出不会转发到下一个。


是什么导致了这个错误?我认为分离输出h_nc_n 就足够了。

如何将前一帧批次的输出传递给下一帧,并让 Torch 分别反向传播每个批次?

【问题讨论】:

【参考方案1】:

显然,我错过了_ 后面的detach()

使用

def on_after_backward(self) -> None:
    self._last_h_n.detach_()
    self._last_c_n.detach_()

有效。


问题是self._last_h_n.detach() 没有更新对由 detach() 分配的 新内存 的引用,因此图形仍然取消引用反向传播所经过的旧变量。 The reference answer 通过 H = H.detach() 解决了这个问题。

更干净(可能更快)的是self._last_h_n.detach_(),它会在原地完成操作。

【讨论】:

以上是关于如何使用 Pytorch 中的截断反向传播(闪电)在很长的序列上运行 LSTM?的主要内容,如果未能解决你的问题,请参考以下文章

了解 PyTorch 中的反向传播

学习笔记Pytorch十二损失函数与反向传播

Pytorch Note13 反向传播算法

pytorch前向传播和反向传播

pytorch中的Variable()

torch教程[3] 使用pytorch自带的反向传播