PyTorch nn.RNN 参数全解析

Posted raelum

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch nn.RNN 参数全解析相关的知识,希望对你有一定的参考价值。

目录

一、简介

torch.nn.RNN 用于构建循环层,其中的计算规则如下:

h t = tanh ⁡ ( W i h x t + b i h + W h h h t − 1 + b h h ) (1) \\boldsymbolh_t=\\tanh(\\bf W_ih\\boldsymbolx_t+\\boldsymbolb_ih+\\bf W_hh\\boldsymbolh_t-1+\\boldsymbolb_hh) \\tag1 ht=tanh(Wihxt+bih+Whhht1+bhh)(1)

其中 h t \\boldsymbolh_t ht t t t 时刻的隐层状态, x t \\boldsymbolx_t xt t t t 时刻的输入。下标 i i i i n p u t input input 的简写,下标 h h h h i d d e n hidden hidden 的简写。 W , b \\bf W,\\boldsymbolb W,b 分别是权重和偏置。

二、前置知识

先回顾一下普通的神经网络,我们在训练它的过程中通常会投喂一小批量的数据。不妨设 batch_size = N \\textbatch\\_size=N batch_size=N,则投喂的数据的形式为:

X = [ x 1 T ⋮ x N T ] N × d \\bf X= \\beginbmatrix \\boldsymbolx_1^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_N^\\text T \\endbmatrix_N\\times d X=x1TxNTN×d

其中 x i = ( x i 1 , x i 2 , ⋯   , x i d ) T \\boldsymbolx_i=(x_i1,x_i2,\\cdots,x_id)^\\text T xi=(xi1,xi2,,xid)T 为特征向量,维数为 d d d

在处理序列问题中,我们会将词元转化成对应的特征向量。例如在处理一个英文句子时,我们通常会通过某种手段将每个单词转化为合适的特征向量。设序列(句子)长度为 L L L,于是在此情景下,一个句子可以表示为:

seq i = [ x i 1 T ⋮ x i L T ] L × d \\textseq_i= \\beginbmatrix \\boldsymbolx_i1^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_iL^\\text T \\endbmatrix_L\\times d seqi=xi1TxiLTL×d

其中的每个 x i j ,    j = 1 , ⋯   , L \\boldsymbolx_ij,\\;j=1,\\cdots, L xij,j=1,,L 都对应了句子 seq i \\textseq_i seqi 中的一个单词。在上述约定下,我们 t t t 时刻投喂给RNN的数据为:

X t = [ x 1 t T ⋮ x N t T ] N × d (2) \\bf X_t= \\beginbmatrix \\boldsymbolx_1t^\\text T \\\\ \\vdots \\\\ \\boldsymbolx_Nt^\\text T \\endbmatrix_N\\times d\\tag2 Xt=x1tTxNtTN×d(2)

从而 ( 1 ) (1) (1) 式改写为

H t = tanh ⁡ ( X t W i h + b i h + H t − 1 W h h + b h h ) (3) \\bf H_t=\\tanh(\\bf X_t\\bf W_ih+\\boldsymbolb_ih+\\bf H_t-1\\bf W_hh+\\boldsymbolb_hh)\\tag3 Ht=tanh(XtWih+bih+Ht1Whh+bhh)(3)

其中 H t , H t − 1 \\bf H_t,\\bf H_t-1 Ht,Ht1 的形状为 N × h N\\times h N×h W i h \\bf W_ih Wih 的形状为 d × h d\\times h d×h W h h \\bf W_hh Whh 的形状为 h × h h\\times h h×h b i h , b h h \\boldsymbolb_ih,\\boldsymbolb_hh bih,bPyTorch建立RNN相关模型

是否使用 nn.RNN 的代码差异

教师用 pytorch RNN 强制

深度学习原理与框架-递归神经网络-RNN_exmaple(代码) 1.rnn.BasicLSTMCell(构造基本网络) 2.tf.nn.dynamic_rnn(执行rnn网络) 3.tf.expa

tf.squeeze和tf.nn.rnn的功能是什么?

Tensorflow 与 Keras 中的 RNN,tf.nn.dynamic_rnn() 的贬值