PyTorch nn.RNN 参数全解析
Posted raelum
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch nn.RNN 参数全解析相关的知识,希望对你有一定的参考价值。
目录
一、简介
torch.nn.RNN
用于构建循环层,其中的计算规则如下:
h t = tanh ( W i h x t + b i h + W h h h t − 1 + b h h ) (1) \\boldsymbolh_t=\\tanh(\\bf W_ih\\boldsymbolx_t+\\boldsymbolb_ih+\\bf W_hh\\boldsymbolh_t-1+\\boldsymbolb_hh) \\tag1 ht=tanh(Wihxt+bih+Whhht−1+bhh)(1)
其中 h t \\boldsymbolh_t ht 是 t t t 时刻的隐层状态, x t \\boldsymbolx_t xt 是 t t t 时刻的输入。下标 i i i 是 i n p u t input input 的简写,下标 h h h 是 h i d d e n hidden hidden 的简写。 W , b \\bf W,\\boldsymbolb W,b 分别是权重和偏置。
二、前置知识
先回顾一下普通的神经网络,我们在训练它的过程中通常会投喂一小批量的数据。不妨设 batch_size = N \\textbatch\\_size=N batch_size=N,则投喂的数据的形式为:
X = [ x 1 T ⋮ x N T ] N × d \\bf X= \\beginbmatrix \\boldsymbolx_1^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_N^\\text T \\endbmatrix_N\\times d X=⎣⎢⎡x1T⋮xNT⎦⎥⎤N×d
其中 x i = ( x i 1 , x i 2 , ⋯ , x i d ) T \\boldsymbolx_i=(x_i1,x_i2,\\cdots,x_id)^\\text T xi=(xi1,xi2,⋯,xid)T 为特征向量,维数为 d d d。
在处理序列问题中,我们会将词元转化成对应的特征向量。例如在处理一个英文句子时,我们通常会通过某种手段将每个单词转化为合适的特征向量。设序列(句子)长度为 L L L,于是在此情景下,一个句子可以表示为:
seq i = [ x i 1 T ⋮ x i L T ] L × d \\textseq_i= \\beginbmatrix \\boldsymbolx_i1^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_iL^\\text T \\endbmatrix_L\\times d seqi=⎣⎢⎡xi1T⋮xiLT⎦⎥⎤L×d
其中的每个 x i j , j = 1 , ⋯ , L \\boldsymbolx_ij,\\;j=1,\\cdots, L xij,j=1,⋯,L 都对应了句子 seq i \\textseq_i seqi 中的一个单词。在上述约定下,我们在 t t t 时刻投喂给RNN的数据为:
X t = [ x 1 t T ⋮ x N t T ] N × d (2) \\bf X_t= \\beginbmatrix \\boldsymbolx_1t^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_Nt^\\text T \\endbmatrix_N\\times d\\tag2 Xt=⎣⎢⎡x1tT⋮xNtT⎦⎥⎤N×d(2)
从而 ( 1 ) (1) (1) 式改写为
H t = tanh ( X t W i h + b i h + H t − 1 W h h + b h h ) (3) \\bf H_t=\\tanh(\\bf X_t\\bf W_ih+\\boldsymbolb_ih+\\bf H_t-1\\bf W_hh+\\boldsymbolb_hh)\\tag3 Ht=tanh(XtWih+bih+Ht−1Whh+bhh)(3)
其中 H t , H t − 1 \\bf H_t,\\bf H_t-1 Ht,Ht−1 的形状为 N × h N\\times h N×h, W i h \\bf W_ih Wih 的形状为 d × h d\\times h d×h, W h h \\bf W_hh Whh 的形状为 h × h h\\times h h×h, b i h , b h h \\boldsymbolb_ih,\\boldsymbolb_hh bih,bPyTorch建立RNN相关模型
深度学习原理与框架-递归神经网络-RNN_exmaple(代码) 1.rnn.BasicLSTMCell(构造基本网络) 2.tf.nn.dynamic_rnn(执行rnn网络) 3.tf.expa