增强学习Recurrent Visual Attention源码解读
Posted shenxiaolu1984
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了增强学习Recurrent Visual Attention源码解读相关的知识,希望对你有一定的参考价值。
Mnih, Volodymyr, Nicolas Heess, and Alex Graves. “Recurrent models of visual attention.” Advances in Neural Information Processing Systems. 2014.
这篇文章处理的任务非常简单:MNIST手写数字分类。但使用了聚焦机制(Visual Attention),不是一次看一张大图进行估计,而是分多次观察小部分图像,根据每次查看结果移动观察位置,最后估计结果。
Yoshua Bengio的高徒,先后供职于LISA和Element Research的Nicolas Leonard用Torch实现了这篇文章的算法。Torch官方cheetsheet的demo中,就包含这篇源码,作者自己的讲解也刊登在Torch的博客中,足见其重要性。
通过这篇源码,我们可以
- 理解聚焦机制中较简单的hard attention
- 了解增强学习的基本流程
- 复习Torch和扩展包dp的相关语法
本文解读训练源码,分三大部分:参数设置,网络构造,训练设置。以下逐次介绍其中重要的语句。
参数设置
除了Torch之外,还需要包含Nicholas Leonard自己编写的两个包。dp:能够简化DL流程,训练过程更“面向对象”;rnn:实现Recurrent网络。
require 'dp'
require 'rnn'
首先使用Torch的CmdLine
类设定一系列参数,存储在opt
中。这是Torch的标准写法。
cmd = torch.CmdLine()
cmd:option('--learningRate', 0.01, 'learning rate at t=0') -- 参数名,参数值,说明
local opt = cmd:parse(arg or {}) --把cmd中的参数传入opt
把数据载入到数据集ds
中,数据是dp包中已经下载好的:
ds = dp[opt.dataset]()
网络构造
这篇源码中模型的写法遵循:由底到顶,先细节后整体。和CNN不同,Recurrent网络带有反馈,呈现较为复杂的多级嵌套结构。请着重关注每个模块的输入、输出和作用部分。
Glimpse网络
输入:图像
I
I
I和观察位置
l
l
l
输出:观察结果
x
x
x
蓝色输入,橙色输出,菱形表示串接:
首先用locationSensor
(左半)提取位置信息
l
l
l中的特征:
locationSensor:add(nn.SelectTable(2)) --选择两个输入中的第二个,位置l
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize)) --Torch中的Linear指全连层
locationSensor:add(nn[opt.transfer]()) --opt.transfer定义一种非线性运算,本文中是ReLU
之后用glimpseSensor
(右半)提取图像
I
I
I位置
l
l
l的特征。
其中SpacialGlimpse是dp中定义的层,提取尺寸为PatchSize的Depth层图像,相邻层比例为Scale。
glimpseSensor:add(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float()) --SpatialGlimpse提取小块金字塔
glimpseSensor:add(nn.Collapse(3)) --压缩第三维
glimpseSensor:add(nn.Linear(ds:imageSize('c')*(opt.glimpsePatchSize^2)*opt.glimpseDepth, opt.glimpseHiddenSize))
glimpseSensor:add(nn[opt.transfer]())
两者结果串接为glimpse
,输出包含位置和纹理信息的
x
x
x,尺寸为hiddenSize:
glimpse:add(nn.ConcatTable():add(locationSensor):add(glimpseSensor))
glimpse:add(nn.JoinTable(1,1)) --把串接数据合并成一个Tensor
glimpse:add(nn.Linear(opt.glimpseHiddenSize+opt.locatorHiddenSize, opt.imageHiddenSize))
glimpse:add(nn[opt.transfer]())
glimpse:add(nn.Linear(opt.imageHiddenSize, opt.hiddenSize)) --从imageHiddenSize到hiddenSize的全连层
作用:通过小范围观测,提取纹理和位置信息。
说明
Torch的基础数据是Tensor,而lua中用Table实现类似数组的功能。nn库中专门有一系列Table层,用于处理涉及这两者的运算。例如:
ConcatTable
- 把若干个输出Tensor放置在一个Table中。
SelectTable
- 从输入的Table中选择一个Tensor。
JoinTable
- 把输入Table中的所有Tensor合并成一个Tensor。
Recurrent网络
输入:和Glimpse网络相同,图像
I
I
I,观察位置
l
l
l。
输出:系统循环状态
r
r
r
使用Recurrent类创建一个包含Glimpse子网络的rnn
框架。Recurrent类的第二个参数(glimpse)指出如何处理输入,第三个参数(recurrent)指出如何处理前一时刻的循环状态。
recurrent = nn.Linear(opt.hiddenSize, opt.hiddenSize)
rnn = nn.Recurrent(opt.hiddenSize, glimpse, recurrent, nn[opt.transfer](), 99999)
作用:通过小范围观测,更新网络循环状态。
nn.Recurrent最后一个参数表示“最多考虑的backward步数”,设定为一个很大的值(99999)。在后续模块中会设定真实的记忆步数rho。
Locator网络
输入:系统循环状态
r
r
r,也就是Recurrent网络的输出
输出:观测位置
l
l
l
这部分核心是dp库中的ReinforceNormal
层:正态分布的强化学习层。dp库中还有其他分布的强化学习层。
locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd, opt.stochastic)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h"))) --对位置l做了归一化:相对图像中心的最大偏移为unitPixel。
ReinforceNormal
层在训练状态下,会以前一层输入为均值,以第一个参数(2*opt.locatorStd)为方差,产生符合高斯分布采样结果;
在训练状态下,如果第二个参数(opt.stochastic)为真,则以相同方式采样,否则直接传递前一层结果。
简单来说,Reinforce层的作用是:在训练时,围绕当前策略(前层输出),探索一些新策略(高斯采样)。具体怎么训练在下篇再说。
作用:利用系统循环状态,决定观测位置。
Attention网络
输入:图像
I
I
I
输出:系统循环状态
r
r
r
直接使用rnn包中的RecurrentAttention层进行定义。
第一个参数(rnn)指明如何处理循环状态
r
r
r的记忆,第二个参数(locator)指明利用循环状态执行何种动作(action)。第三个参数(rho)指明循环步数,第四个参数指明隐变量维度。
attention = nn.RecurrentAttention(rnn, locator, opt.rho, {opt.hiddenSize})
作用:输入图像,循环固定步数,每一步更新系统循环状态。
Agent网络
输入:图像
I
I
I
输出:字符属于各类的概率向量
p
p
p
在前面attention
网络的基础上,只对系统循环变量做简单非线性变换,即得到图像属于各类字符的概率
p
p
p。
agent:add(attention)
agent:add(nn.SelectTable(-1))
agent:add(nn.Linear(opt.hiddenSize, #ds:classes()))
agent:add(nn.LogSoftMax()) -- 这里输出分类结果
由于系统中存在强化学习层ReinforceNormal
,所以需要一个baseline变量
b
b
b。这里利用ConcatTable
把
b
b
b和分类结果合并到一个Table里输出。
seq:add(nn.Constant(1,1))
seq:add(nn.Add(1))
concat = nn.ConcatTable():add(nn.Identity()):add(seq)
concat2 = nn.ConcatTable():add(nn.Identity()):add(concat)
agent:add(concat2)
整个系有两组输出:分类结果 p p p,以及分类结果+baseline对 { p , b } \\{p,b\\} {p,b}。
作用:把系统隐变量转化成估计结果,并且输出一个baseline,便于后续优化。
训练设置
在dp库中,训练过程是分层定义的,为了说明清晰,倒序讲解。
首先(在代码里是最后),定义实验xp
,使用的模型就是前述网络agent
:
xp = dp.Experiment{
model = agent, -- nn.Sequential, 待优化模型
optimizer = train, -- dp.Optimizer,训练
validator = valid, -- dp.Evaluator,验证
tester = tester, -- dp.Evaluator,测试
observer = { -- 设定log
ad,
dp.FileLogger(),
dp.EarlyStopper{
max_epochs = opt.maxTries,
error_report={'validator','feedback','confusion','accuracy'},
maximize = true
}
},
random_seed = os.time(),
max_epoch = opt.maxEpoch -- 最大迭代次数
}
训练
train
是一个dp.Optimizer
类型对象,这个类继承自抽象类dp.propogator
,需要指明6个参数:
train = dp.Optimizer{
loss=..., epoch_callback=..., callback = ..., feedback - ...,sampler = ..., progress = ...
}
loss
定义了损失层。用ParallelCriterion
把监督学习的ClassNLLCriterion
和增强学习的VRClassReward
并列优化。
loss = nn.ParallelCriterion(true)
:add(nn.ModuleCriterion(nn.ClassNLLCriterion(), nil,nn.Convert())) -- 监督学习:negative log-likelihood
:add(nn.ModuleCriterion(nn.VRClassReward(agent, opt.rewardScale), nil, nn.Convert())) -- 增强学习:得分最高类与标定相同反馈1,否则反馈-1
epoch_callback
函数设定每个epoch结束时执行的动作,一般用来调整opt
中的学习率。
epoch_callback = function(model, report) -- called every epoch
if report.epoch > 0 then
opt.learningRate = opt.learningRate + opt.decayFactor
opt.learningRate = math.max(opt.minLR, opt.learningRate)
if not opt.silent then
print("learningRate", opt.learningRate)
end
end
end
callback
是核心函数,更新模型参数:
callback = function(model, report)
if opt.cutoffNorm > 0 then
local norm = model:gradParamClip(opt.cutoffNorm) -- dpnn扩展,约束梯度,有益于RNN
opt.meanNorm = opt.meanNorm and (opt.meanNorm*0.9 + norm*0.1) or norm;
if opt.lastEpoch < report.epoch and not opt.silent then
print("mean gradParam norm", opt.meanNorm)
end
end
model:updateGradParameters(opt.momentum) -- dpnn扩展,根据momentum更新梯度
model:updateParameters(opt.learningRate) -- 根据学习率更新参数
model:maxParamNorm(opt.maxOutNorm) -- dpnn扩展,约束参数范围
model:zeroGradParameters() -- 梯度置零
end
feedback
提供I/O用来生成报告,这里输出分类结果与真值比较的confusion matrix。回忆一下:网络的输出是
{
p
,
{
p
,
b
}
}
\\{p,\\{p,b\\}\\}
{p,{p,b}},所以真正的输出用SelectTable(1)
获得。
feedback = dp.Confusion{output_module=nn.SelectTable(1)}
sampler
决定如何从训练集中采样:设定epoch和batch大小。
sampler = dp.ShuffleSampler{
epoch_size = opt.trainEpochSize, batch_size = opt.batchSize
}
progress
是个布尔型,控制是否显示进度条。
progress = opt.progress
验证与测试
valid
是一个dp.Evaluator
类成员变量,同样继承自dp.propogator
。只需要指明feedback
,sampler
,progress
这三个参数即可。
valid = dp.Evaluator{
feedback = dp.Confusion{output_module=nn.SelectTable(1)},
sampler = dp.Sampler{epoch_size = opt.validEpochSize, batch_size = opt.batchSize},
progress = opt.progress
}
test
和valid
类似,连进度条都不用打了
tester = dp.Evaluator{
feedback = dp.Confusion{output_module=nn.SelectTable(1)},
sampler = dp.Sampler{batch_size = opt.batchSize}
}
执行
在这一步,把已经读取好的数据集ds
输入到实验xp
中去:
xp:run(ds)
以上是关于增强学习Recurrent Visual Attention源码解读的主要内容,如果未能解决你的问题,请参考以下文章
Re3 : Real-Time Recurrent Regression Networks for Visual Tracking of Generic Objects
机器学习笔记:Dilated Recurrent Neural Networks
Visual Question Answering with Memory
带有RNN循环神经网络的机器学习 4 NLP 从零到英雄 ML with Recurrent Neural Networks