PyTorch-LSTM
Posted xidian-mao
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch-LSTM相关的知识,希望对你有一定的参考价值。
1 import torch 2 import torch.nn as nn 3 4 torch.random.manual_seed(10) 5 6 input_size = 2 # 输入向量维度 7 hidden_size = 4 # 隐层层维度 8 num_layers = 2 # 层数 9 10 lstm = nn.LSTM(input_size, hidden_size, num_layers) 11 12 13 # Input: 14 15 # input of shape (sep_len, bath, input_size) 16 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size) 17 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size) 18 19 # Output: 20 # output of shape (sep_len, bath, num_directions * hidden_size) 21 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size) 22 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size) 23 24 # two ways 25 Input = torch.randn(4, 3, 2) 26 h = torch.randn(2, 3, 4) 27 c = torch.randn(2, 3, 4) 28 output = None 29 30 # first 31 h1 = h 32 c1 = c 33 for it in Input: 34 output, (h1, c1) = lstm(it.view(1, 3, -1), (h1, c1)) 35 print((output == h1[-1]).all().item()) 36 print(output) 37 38 # second 39 output1, (h, c) = lstm(Input,(h, c)) 40 print(output1[-1]) 41 # print(output1[-1] == output) 精度的问题
以上是关于PyTorch-LSTM的主要内容,如果未能解决你的问题,请参考以下文章