LSTM cell in detail and how to implement it in PyTorch
Posted by quinn-yann
References:
https://medium.com/@andre.holzner/lstm-cells-in-pytorch-fab924a78b1c
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
LSTM cells in PyTorch
This is an annotated illustration of the LSTM cell in PyTorch (admittedly inspired by the diagrams in Christopher Olah’s excellent blog article):
The yellow boxes correspond to matrix multiplications followed by non-linearities. W represents the weight matrices; the bias terms b have been omitted for simplicity. The mathematical symbols used in this diagram correspond to those used in PyTorch’s documentation of torch.nn.LSTM:
- x(t): the external input (e.g. from training data) at time t
- h(t-1)/h(t): the hidden state at times t-1 (‘input’) or t (‘output’). Despite its name, it also serves as the cell’s output and as the input to the next layer of LSTM cells (in multi-layer networks)
- c(t-1)/c(t): the ‘cell state’ or ‘memory’ at times t-1 and t
- f(t): the result of the forget gate. For values close to zero the cell will ‘forget’ its memories c(t-1) from the past; for values close to one it will remember its history.
- i(t): the result of the input gate, determining how important the (transformed) new external input is.
- g(t): the result of the cell gate, a non-linear (tanh) transformation of the new external input x(t) and the previous hidden state h(t-1)
- o(t): the result of the output gate which controls how much of the new cell state c(t) should go to the output (and the hidden state)
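Written out as equations, the quantities in the list above combine as follows. This is a sketch in the notation of the torch.nn.LSTM documentation: the per-gate weight and bias names (W_ii, W_hi, b_ii, and so on) follow that documentation rather than the diagram, σ is the sigmoid function, and ⊙ is the element-wise product.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}\right) \\
f_t &= \sigma\left(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}\right) \\
g_t &= \tanh\left(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}\right) \\
o_t &= \sigma\left(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh\left(c_t\right)
\end{aligned}
```

The last two lines show why f(t) close to zero erases the old memory c(t-1), and why o(t) decides how much of the new cell state reaches the hidden state.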
It is also instructive to look at the implementation of torch.nn._functions.rnn.LSTMCell (a private module in older PyTorch releases):
    def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
        if input.is_cuda:
            ...
        h_t_1, c_t_1 = hidden
        gates = F.linear(input, w_ih, b_ih) + F.linear(h_t_1, w_hh, b_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        c_t = (forgetgate * c_t_1) + (ingate * cellgate)
        h_t = outgate * F.tanh(c_t)
        return h_t, c_t
The second argument (hidden) is expected to be a tuple (h(t-1), c(t-1)), i.e. the hidden state and the cell/memory state at time t-1. The return value has the same format, but for time t.
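To check this reading of the cell, the sketch below re-implements one step by hand and compares it with torch.nn.LSTMCell from current PyTorch. The function name lstm_cell_step and the tensor sizes are illustrative assumptions, not part of the original article; torch.sigmoid and torch.tanh stand in for the deprecated F.sigmoid and F.tanh.

```python
import torch
import torch.nn.functional as F


def lstm_cell_step(x, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """One LSTM step (hypothetical helper mirroring the legacy function)."""
    h_prev, c_prev = hidden  # (h(t-1), c(t-1))
    # All four gate pre-activations are computed in one stacked matmul.
    gates = F.linear(x, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    # Split along the feature dimension in the order i, f, g, o.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    c_t = forgetgate * c_prev + ingate * cellgate
    h_t = outgate * torch.tanh(c_t)
    return h_t, c_t


torch.manual_seed(0)
batch, input_size, hidden_size = 3, 5, 4  # arbitrary illustrative sizes
cell = torch.nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)

with torch.no_grad():
    # Reference result from the built-in module.
    h_ref, c_ref = cell(x, (h0, c0))
    # Manual result using the module's own stacked weights and biases.
    h_man, c_man = lstm_cell_step(
        x, (h0, c0),
        cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh,
    )

print(torch.allclose(h_ref, h_man, atol=1e-5))
print(torch.allclose(c_ref, c_man, atol=1e-5))
```

Because weight_ih and weight_hh stack the four gate matrices along their first dimension, chunk(4, 1) on the combined result recovers the input, forget, cell and output gates in that order.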