LSTM Cell in detail and how to implement it in PyTorch

Posted by quinn-yann

tags:

Preface: This article was compiled by the editors of cha138.com. It mainly introduces the LSTM cell in detail and how to implement it in PyTorch; we hope you find it a useful reference.

Refer to:

https://medium.com/@andre.holzner/lstm-cells-in-pytorch-fab924a78b1c

http://colah.github.io/posts/2015-08-Understanding-LSTMs/


LSTM cells in PyTorch

This is an annotated illustration of the LSTM cell in PyTorch (admittedly inspired by the diagrams in Christopher Olah’s excellent blog article):

 

[Figure: annotated diagram of the LSTM cell]

The yellow boxes correspond to matrix multiplications followed by non-linearities. W denotes the weight matrices; the bias terms b have been omitted for simplicity. The mathematical symbols used in this diagram correspond to those used in PyTorch’s documentation of torch.nn.LSTM (a minimal shape check follows the list below):

  • x(t): the external input (e.g. from the training data) at time t
  • h(t-1)/h(t): the hidden state at time t-1 (‘input’) and t (‘output’). Despite its name, it also serves as the cell’s output and as the input to the next layer of LSTM cells in multi-layer networks.
  • c(t-1)/c(t): the ‘cell state’ or ‘memory’ at times t-1 and t
  • f(t): the result of the forget gate. For values close to zero the cell ‘forgets’ its memory c(t-1) from the past; for values close to one it remembers its history.
  • i(t): the result of the input gate, which determines how important the (transformed) new external input is
  • g(t): the result of the cell gate, a non-linear transformation of the new external input x(t)
  • o(t): the result of the output gate, which controls how much of the new cell state c(t) goes to the output (and the hidden state)
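
To see how these symbols map onto concrete tensors, here is a minimal shape check with torch.nn.LSTM. It is only a sketch: the input size of 10, the hidden size of 20 and the sequence/batch sizes are arbitrary values chosen for illustration.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

x   = torch.randn(5, 3, 10)    # x(t): 5 time steps, batch of 3, 10 features
h_0 = torch.zeros(1, 3, 20)    # h(t-1) for the first step (one layer)
c_0 = torch.zeros(1, 3, 20)    # c(t-1) for the first step (one layer)

output, (h_n, c_n) = lstm(x, (h_0, c_0))
print(output.shape)            # torch.Size([5, 3, 20]) -- h(t) at every time step
print(h_n.shape)               # torch.Size([1, 3, 20]) -- h(t) at the last step
print(c_n.shape)               # torch.Size([1, 3, 20]) -- c(t) at the last step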

It is also instructive to look at the implementation of torch.nn._functions.rnn.LSTMCell:

def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        ...  # CUDA inputs are dispatched to a fused kernel (omitted in this excerpt)

    h_t_1, c_t_1 = hidden
    # one matrix multiplication of the input and one of the previous hidden
    # state yield the pre-activations of all four gates at once
    gates = F.linear(input, w_ih, b_ih) + F.linear(h_t_1, w_hh, b_hh)

    # split along the feature dimension into input, forget, cell and output gates
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # sigmoid gates lie in [0, 1]; the candidate cell values g(t) lie in [-1, 1]
    ingate     = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate   = F.tanh(cellgate)
    outgate    = F.sigmoid(outgate)

    # new cell state: keep part of the old memory, add part of the new input
    c_t = (forgetgate * c_t_1) + (ingate * cellgate)
    # new hidden state / output: a gated, squashed view of the cell state
    h_t = outgate * F.tanh(c_t)

    return h_t, c_t
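
As a sanity check, the same arithmetic can be reproduced from the weights of the public torch.nn.LSTMCell module and compared against the module’s own output. This is only a sketch: it assumes the gate pre-activations are stored in the order input, forget, cell, output (matching the chunk above), and all sizes are arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=3)

x   = torch.randn(2, 4)    # a batch of 2 inputs x(t)
h_0 = torch.zeros(2, 3)    # h(t-1)
c_0 = torch.zeros(2, 3)    # c(t-1)

# manual computation following the function body above
gates = F.linear(x, cell.weight_ih, cell.bias_ih) + F.linear(h_0, cell.weight_hh, cell.bias_hh)
i, f, g, o = gates.chunk(4, 1)
c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
h_1 = torch.sigmoid(o) * torch.tanh(c_1)

# the module's own computation
h_ref, c_ref = cell(x, (h_0, c_0))
print(torch.allclose(h_1, h_ref), torch.allclose(c_1, c_ref))   # should print: True True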

 

The second argument (hidden) is in fact expected to be a tuple (h(t-1), c(t-1)), i.e. (hidden state at time t-1, cell/memory state at time t-1), and the return value has the same format but for time t.
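
Putting this together, a minimal usage sketch of the public torch.nn.LSTMCell module driven over a short sequence might look as follows (the input size of 10, the hidden size of 20 and the sequence/batch sizes are again arbitrary):

import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=10, hidden_size=20)

batch = 3
x_seq = torch.randn(5, batch, 10)       # five time steps of external input x(t)
h_t = torch.zeros(batch, 20)            # h(t-1) for the first step
c_t = torch.zeros(batch, 20)            # c(t-1) for the first step

for x_t in x_seq:                       # one time step per iteration
    h_t, c_t = cell(x_t, (h_t, c_t))    # pass (h(t-1), c(t-1)), get back (h(t), c(t))

print(h_t.shape, c_t.shape)             # torch.Size([3, 20]) torch.Size([3, 20])

Each iteration consumes x(t) together with the tuple (h(t-1), c(t-1)) and returns (h(t), c(t)), which is fed back in at the next step.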

 
