CNN之上的LSTM

Posted

技术标签:

【中文标题】CNN之上的LSTM【英文标题】:LSTM on top of CNN 【发布时间】:2016-10-11 10:22:35 【问题描述】:

我在 Torch 中有以下 lstm 模型实现,我从这里获取: https://github.com/wojzaremba/lstm/blob/master/main.lua

我对以下代码有疑问:

local function create_network()
  local x                = nn.Identity()()
  local y                = nn.Identity()()
  local prev_s           = nn.Identity()()
  local i                = [0] = LookupTable(params.vocab_size,
                                                    params.rnn_size)(x)
  local next_s           = 
  local split         = prev_s:split(2 * params.layers)
  for layer_idx = 1, params.layers do
    local prev_c         = split[2 * layer_idx - 1]
    local prev_h         = split[2 * layer_idx]
    local dropped        = nn.Dropout(params.dropout)(i[layer_idx - 1])
    local next_c, next_h = lstm(dropped, prev_c, prev_h)
    table.insert(next_s, next_c)
    table.insert(next_s, next_h)
    i[layer_idx] = next_h
  end
  local h2y              = nn.Linear(params.rnn_size, params.vocab_size)
  local dropped          = nn.Dropout(params.dropout)(i[params.layers])
  local pred             = nn.LogSoftMax()(h2y(dropped))
  local err              = nn.ClassNLLCriterion()(pred, y)
  local module           = nn.gModule(x, y, prev_s,
                                      err, nn.Identity()(next_s))
  module:getParameters():uniform(-params.init_weight, params.init_weight)
  return transfer_data(module)
end

在lstm输入的嵌入部分,代码在处理ptb数据库时使用了LookupTable层,现在我想知道如何使用LookupTable定义其他嵌入到不同类型的数据。特别是,输入是 RGB 图像,嵌入将是 CNN 模型之一,例如没有完全连接层的 AlexNet。 (https://gist.github.com/gcr/0bab9929dfee95164a4d)

它看起来对我来说太模糊了。 为此目的有更好的设计吗? 如何在 CNN 模型之上创建 LSTM?

【问题讨论】:

【参考方案1】:

Torch nn.LookupTable 只是 doing 张量 index 在其权重张量上。在您指定的代码中,它还用于学习词向量,因为它包含在 nngraph 模型中。如果你有一个预训练的模型,你可以将它的权重设置为 LookupTable,但是这一次,你不应该将它包含在 nngraph 中。权重张量的维度应该是 nIndex(例如,您有多少不同的图像)x nOutput(例如,LSTM 隐藏大小 - 代码中的 rnn_size)。或者,根本不用 LookupTable,直接指定输入 Tensor 即可。

【讨论】:

以上是关于CNN之上的LSTM的主要内容,如果未能解决你的问题,请参考以下文章

Fast R-CNN(RoI)简介

卷积神经网络(CNN)基础介绍

第17篇TextCNN

LSTM介绍LSTM变种常用架构以及相关文献梳理

DL-4长短期记忆网络(LSTM)

基于区域的CNN(R-CNN)