Understanding the RNN Flow
Posted by 三つ叶
```python
import torch
import torch.nn as nn

# nn.RNN(input_size, hidden_size, num_layers): input_feature_size=10, hidden_size=20 (also the output feature size), num_layers=2
rnn = nn.RNN(10, 20, 2)
print(rnn)
x = torch.randn(5, 3, 10)   # input: (sequence_length, batch_size, input_feature_size); batch_first=False by default
h0 = torch.randn(2, 3, 20)  # h0: (num_layers, batch_size, hidden_size)
output, hn = rnn(x, h0)
print(output.shape)  # (sequence_length, batch_size, hidden_size) -> torch.Size([5, 3, 20])
print(hn.shape)      # (num_layers, batch_size, hidden_size) -> torch.Size([2, 3, 20])
```
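The two return values are easy to mix up. `output` holds the hidden state of the top (last) layer at every time step, while `hn` holds the hidden state of every layer at the final time step. A minimal check, reusing the same sizes as the code above (the seed is only there for reproducibility):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(10, 20, 2)     # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)   # (sequence_length, batch_size, input_feature_size)
h0 = torch.randn(2, 3, 20)  # (num_layers, batch_size, hidden_size)
output, hn = rnn(x, h0)

# output[-1]: top layer at the last time step
# hn[-1]:     last layer at the last time step
# Both describe the same hidden state, so they should match.
print(torch.allclose(output[-1], hn[-1]))  # True
```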
Read the flow below alongside the code above.
$$
\begin{aligned}
X &\rightarrow (\text{sequence\_length}, \text{batch\_size}, \text{input\_feature\_size}) \\
X_1, X_2, \ldots, X_{\text{sequence\_length}} &\rightarrow (\text{batch\_size}, \text{input\_feature\_size}) \\
W_{xh} &\rightarrow (\text{input\_feature\_size}, \text{hidden\_layers}) \\
W_{hh} &\rightarrow (\text{hidden\_layers}, \text{hidden\_layers}) \\
W_{hq} &\rightarrow (\text{hidden\_layers}, \text{output\_feature\_size})
\end{aligned}
$$
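To see these shapes in action, here is a from-scratch sketch of a single-layer forward pass written with the same symbol names. Assumptions not in the original: a `tanh` activation, bias vectors `b_h` and `b_q`, a zero initial hidden state, and random weights; `output_feature_size = 7` is an arbitrary hypothetical value chosen only to show that `W_hq` can map to any output size.

```python
import torch

# Hypothetical sizes; hidden_layers here means the hidden size (number of hidden units).
sequence_length, batch_size = 5, 3
input_feature_size, hidden_layers, output_feature_size = 10, 20, 7

X = torch.randn(sequence_length, batch_size, input_feature_size)
W_xh = torch.randn(input_feature_size, hidden_layers)
W_hh = torch.randn(hidden_layers, hidden_layers)
b_h = torch.zeros(hidden_layers)
W_hq = torch.randn(hidden_layers, output_feature_size)
b_q = torch.zeros(output_feature_size)

H = torch.zeros(batch_size, hidden_layers)  # initial hidden state
outputs = []
# Iterating over dim 0 yields X_1, ..., X_sequence_length, each (batch_size, input_feature_size)
for X_t in X:
    H = torch.tanh(X_t @ W_xh + H @ W_hh + b_h)  # (batch_size, hidden_layers)
    outputs.append(H @ W_hq + b_q)               # (batch_size, output_feature_size)

O = torch.stack(outputs)  # (sequence_length, batch_size, output_feature_size)
print(O.shape)  # torch.Size([5, 3, 7])
```

The same three weight matrices are reused at every iteration of the loop; only `H` changes from step to step.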
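Ultimately, an RNN is still just a neural network, so don't overcomplicate it: at every time step it applies the same weights (W_xh, W_hh, W_hq) and passes the hidden state on to the next step.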
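Written as equations, one time step is just two affine maps plus a nonlinearity. This sketch assumes a `tanh` activation (the default `nonlinearity` of `nn.RNN`) and bias vectors $b_h$ and $b_q$, neither of which appears in the shape list above; $H_0$ is the initial hidden state (`h0` in the code):

$$
\begin{aligned}
H_t &= \tanh\left(X_t W_{xh} + H_{t-1} W_{hh} + b_h\right), && H_t \in \mathbb{R}^{\text{batch\_size} \times \text{hidden\_layers}} \\
O_t &= H_t W_{hq} + b_q, && O_t \in \mathbb{R}^{\text{batch\_size} \times \text{output\_feature\_size}}
\end{aligned}
$$

The shapes check out term by term: $X_t W_{xh}$ is $(\text{batch\_size}, \text{input\_feature\_size}) \times (\text{input\_feature\_size}, \text{hidden\_layers})$ and $H_{t-1} W_{hh}$ is $(\text{batch\_size}, \text{hidden\_layers}) \times (\text{hidden\_layers}, \text{hidden\_layers})$, so both sums land in $(\text{batch\_size}, \text{hidden\_layers})$.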
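In PyTorch, hidden_layers (really the hidden size, i.e. the number of hidden units, not the number of layers) is the same as output_feature_size: `nn.RNN` has no separate output projection W_hq, and `output` is simply the last layer's hidden state at every time step.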
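One way to confirm this is to recompute a single-layer `nn.RNN` by hand from its own parameters (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`) and compare. This is a quick check assuming the defaults `nonlinearity='tanh'` and `bias=True`; PyTorch stores the weights as `(hidden_size, input_size)` and `(hidden_size, hidden_size)`, hence the transposes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(10, 20, 1)     # one layer keeps the comparison short
x = torch.randn(5, 3, 10)   # (sequence_length, batch_size, input_feature_size)
h0 = torch.zeros(1, 3, 20)  # (num_layers, batch_size, hidden_size)
output, hn = rnn(x, h0)

W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0

h = h0[0]  # (batch_size, hidden_size)
steps = []
with torch.no_grad():
    for x_t in x:
        # Same recurrence as the equations, with PyTorch's two bias terms
        h = torch.tanh(x_t @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        steps.append(h)
manual = torch.stack(steps)  # (sequence_length, batch_size, hidden_size)

# No output projection: output is exactly the stacked hidden states.
print(torch.allclose(manual, output, atol=1e-5))  # True
```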